diff --git a/samples/agent/adk/contact_lookup/agent.py b/samples/agent/adk/contact_lookup/agent.py index aa8f0367d..d40bd423d 100644 --- a/samples/agent/adk/contact_lookup/agent.py +++ b/samples/agent/adk/contact_lookup/agent.py @@ -15,6 +15,7 @@ import json import logging import os +from collections import OrderedDict from collections.abc import AsyncIterable from dataclasses import dataclass from typing import Any, Dict, Optional @@ -65,7 +66,8 @@ def __init__(self, base_url: str): self._schema_managers: Dict[str, A2uiSchemaManager] = {} self._ui_runners: Dict[str, Runner] = {} - self._parsers: Dict[str, A2uiStreamParser] = {} + self._parsers: OrderedDict[str, A2uiStreamParser] = OrderedDict() + self._max_parsers = 1000 # Max active sessions to keep in memory for version in [VERSION_0_8, VERSION_0_9]: schema_manager = self._build_schema_manager(version) @@ -259,8 +261,12 @@ async def token_stream(): if selected_catalog: from a2ui.core.parser.streaming import A2uiStreamParser - if session_id not in self._parsers: + if session_id in self._parsers: + self._parsers.move_to_end(session_id) + else: self._parsers[session_id] = A2uiStreamParser(catalog=selected_catalog) + if len(self._parsers) > self._max_parsers: + self._parsers.popitem(last=False) async for part in stream_response_to_parts( self._parsers[session_id], diff --git a/samples/agent/adk/custom-components-example/agent.py b/samples/agent/adk/custom-components-example/agent.py index 00428a5f6..f160e503d 100644 --- a/samples/agent/adk/custom-components-example/agent.py +++ b/samples/agent/adk/custom-components-example/agent.py @@ -15,6 +15,7 @@ import json import logging import os +from collections import OrderedDict from collections.abc import AsyncIterable from dataclasses import dataclass from typing import Any, Dict, Optional @@ -22,6 +23,8 @@ import jsonschema from a2ui_examples import load_floor_plan_example +from a2ui.core.parser.streaming import A2uiStreamParser +from google.adk.agents import run_config from google.adk.agents.llm_agent import LlmAgent from google.adk.artifacts import InMemoryArtifactService from google.adk.memory.in_memory_memory_service import InMemoryMemoryService @@ -45,7 +48,7 @@ from a2ui.core.schema.manager import A2uiSchemaManager from a2ui.core.parser.parser import parse_response, ResponsePart from a2ui.basic_catalog.provider import BasicCatalog -from a2ui.a2a import create_a2ui_part, get_a2ui_agent_extension, parse_response_to_parts +from a2ui.a2a import create_a2ui_part, get_a2ui_agent_extension, parse_response_to_parts, stream_response_to_parts logger = logging.getLogger(__name__) @@ -63,6 +66,8 @@ def __init__(self, base_url: str): self._schema_managers: Dict[str, A2uiSchemaManager] = {} self._ui_runners: Dict[str, Runner] = {} + self._parsers: OrderedDict[str, A2uiStreamParser] = OrderedDict() + self._max_parsers = 1000 # Max active sessions to keep in memory for version in [VERSION_0_8, VERSION_0_9]: schema_manager = self._build_schema_manager(version) @@ -382,27 +387,48 @@ async def stream( current_message = types.Content( role="user", parts=[types.Part.from_text(text=current_query_text)] ) - final_response_content = None - async for event in runner.run_async( - user_id=self._user_id, - session_id=session.id, - new_message=current_message, - ): - logger.info(f"Event from runner: {event}") - if event.is_final_response(): - if event.content and event.content.parts and event.content.parts[0].text: - final_response_content = "\n".join( - [p.text for p in event.content.parts if p.text] - ) - break # Got the final response, stop consuming events + full_content_list = [] + + async def token_stream(): + async for event in runner.run_async( + user_id=self._user_id, + session_id=session.id, + run_config=run_config.RunConfig( + streaming_mode=run_config.StreamingMode.SSE + ), + new_message=current_message, + ): + if event.content and event.content.parts: + for p in event.content.parts: + if p.text: + full_content_list.append(p.text) + yield p.text + + if selected_catalog: + if session_id in self._parsers: + self._parsers.move_to_end(session_id) else: - logger.info(f"Intermediate event: {event}") - # Yield intermediate updates on every attempt + self._parsers[session_id] = A2uiStreamParser(catalog=selected_catalog) + if len(self._parsers) > self._max_parsers: + self._parsers.popitem(last=False) + + async for part in stream_response_to_parts( + self._parsers[session_id], + token_stream(), + ): yield { "is_task_complete": False, - "updates": self.get_processing_message(), + "parts": [part], } + else: + async for token in token_stream(): + yield { + "is_task_complete": False, + "updates": token, + } + + final_response_content = "".join(full_content_list) if full_content_list else None if final_response_content is None: logger.warning( diff --git a/samples/agent/adk/custom-components-example/agent_executor.py b/samples/agent/adk/custom-components-example/agent_executor.py index ad249239c..7a5f82de6 100644 --- a/samples/agent/adk/custom-components-example/agent_executor.py +++ b/samples/agent/adk/custom-components-example/agent_executor.py @@ -178,10 +178,14 @@ async def execute( ): is_task_complete = item["is_task_complete"] if not is_task_complete: - await updater.update_status( - TaskState.working, - new_agent_text_message(item["updates"], task.context_id, task.id), - ) + message = None + if "parts" in item: + message = new_agent_parts_message(item["parts"], task.context_id, task.id) + elif "updates" in item: + message = new_agent_text_message(item["updates"], task.context_id, task.id) + + if message: + await updater.update_status(TaskState.working, message) continue final_state = TaskState.input_required # Default diff --git a/samples/agent/adk/restaurant_finder/agent.py b/samples/agent/adk/restaurant_finder/agent.py index 1a9b70e57..a3f799246 100644 --- a/samples/agent/adk/restaurant_finder/agent.py +++ b/samples/agent/adk/restaurant_finder/agent.py @@ -15,6 +15,7 @@ import json import logging import os +from collections import OrderedDict from collections.abc import AsyncIterable from typing import Any, Optional, Dict @@ -67,7 +68,8 @@ def __init__(self, base_url: str): self._schema_managers: Dict[str, A2uiSchemaManager] = {} self._ui_runners: Dict[str, Runner] = {} - self._parsers = {} + self._parsers = OrderedDict() + self._max_parsers = 1000 # Max active sessions to keep in memory for version in [VERSION_0_8, VERSION_0_9]: schema_manager = self._build_schema_manager(version) @@ -255,8 +257,12 @@ async def token_stream(): if selected_catalog: from a2ui.core.parser.streaming import A2uiStreamParser - if session_id not in self._parsers: + if session_id in self._parsers: + self._parsers.move_to_end(session_id) + else: self._parsers[session_id] = A2uiStreamParser(catalog=selected_catalog) + if len(self._parsers) > self._max_parsers: + self._parsers.popitem(last=False) async for part in stream_response_to_parts( self._parsers[session_id], diff --git a/samples/client/lit/custom-components-example/client.ts b/samples/client/lit/custom-components-example/client.ts index 8e776caa3..3228e4969 100644 --- a/samples/client/lit/custom-components-example/client.ts +++ b/samples/client/lit/custom-components-example/client.ts @@ -39,7 +39,8 @@ export class A2UIClient { } async send( - message: v0_8.Types.A2UIClientEventMessage + message: v0_8.Types.A2UIClientEventMessage, + onChunk?: (messages: v0_8.Types.ServerToClientMessage[]) => void ): Promise { const catalog = componentRegistry.getInlineCatalog(); const finalMessage = { @@ -56,22 +57,90 @@ export class A2UIClient { method: "POST", }); - if (response.ok) { - const data = (await response.json()) as A2AServerPayload; - const messages: v0_8.Types.ServerToClientMessage[] = []; - if ("error" in data) { - throw new Error(data.error); - } else { - for (const item of data) { - if (item.kind === "text") continue; - messages.push(item.data); + if (!response.ok) { + const error = (await response.json()) as { error: string }; + throw new Error(error.error); + } + + const contentType = response.headers.get("content-type"); + const messages: v0_8.Types.ServerToClientMessage[] = []; + + if (contentType?.includes("text/event-stream")) { + const reader = response.body?.getReader(); + if (!reader) throw new Error("No response body"); + const decoder = new TextDecoder(); + let buffer = ""; + + while (true) { + const { done, value } = await reader.read(); + if (done) break; + + buffer += decoder.decode(value, { stream: true }); + const lines = buffer.split("\n\n"); + buffer = lines.pop() || ""; + + for (const line of lines) { + if (line.startsWith("data: ")) { + const jsonStr = line.replace(/^data:\s*/, ""); + try { + const parsed = JSON.parse(jsonStr); + if ("error" in parsed) { + throw new Error(parsed.error); + } else { + const chunkMessages = this.#extractMessages(parsed); + if (chunkMessages.length > 0) { + messages.push(...chunkMessages); + onChunk?.(chunkMessages); + } + } + } catch (e) { + console.error("Error parsing SSE data:", e, jsonStr); + } + } } } return messages; } - const error = (await response.json()) as { error: string }; - throw new Error(error.error); + const data = (await response.json()) as any; + if (data && typeof data === 'object' && "error" in data) { + throw new Error(data.error); + } else { + const extracted = this.#extractMessages(data); + messages.push(...extracted); + if (messages.length > 0) { + onChunk?.(messages); + } + } + return messages; + } + + #extractMessages(data: any): v0_8.Types.ServerToClientMessage[] { + let items: any[] = []; + if (data.messages && Array.isArray(data.messages)) { + items = data.messages; + } else { + items = Array.isArray(data) + ? data + : (data.kind === "message" && Array.isArray(data.parts) ? data.parts : [data]); + } + + const messages: v0_8.Types.ServerToClientMessage[] = []; + for (const item of items) { + if (item.kind === "message" && Array.isArray(item.parts)) { + for (const part of item.parts) { + if (part.data) { + messages.push(part.data); + } + } + } else { + if (item.kind === "text") continue; + if (item.data) { + messages.push(item.data); + } + } + } + return messages; } } registerContactComponents(); diff --git a/samples/client/lit/custom-components-example/contact.ts b/samples/client/lit/custom-components-example/contact.ts index 18f50be86..283592062 100644 --- a/samples/client/lit/custom-components-example/contact.ts +++ b/samples/client/lit/custom-components-example/contact.ts @@ -186,6 +186,23 @@ export class A2UIContactFinder extends SignalWatcher(LitElement) { animation: spin 1s linear infinite; } + .rendering-indicator { + display: flex; + align-items: center; + justify-content: center; + padding: 16px; + color: var(--p-40); + font-size: 14px; + border-top: 1px solid var(--n-90); + margin-top: 16px; + width: 100%; + + & .g-icon { + margin-right: 8px; + font-size: 16px; + } + } + @keyframes spin { to { transform: rotate(360deg); @@ -242,9 +259,9 @@ export class A2UIContactFinder extends SignalWatcher(LitElement) { if (!body) { return; } - const message: v0_8.Types.A2UIClientEventMessage = { - request: body, - }; + const message: v0_8.Types.A2UIClientEventMessage = { + request: body, + }; await this.#sendAndProcessMessage(message); }} > @@ -289,89 +306,79 @@ export class A2UIContactFinder extends SignalWatcher(LitElement) { return nothing; } - return html`
- ${repeat( - surfaces, - ([surfaceId]) => surfaceId, - ([surfaceId, surface]) => { - return html` + return html` + ${this.#requesting + ? html`
+ progress_activity + Rendering UI... +
` + : nothing} +
+ ${repeat( + surfaces, + ([surfaceId]) => surfaceId, + ([surfaceId, surface]) => { + return html`
- ${this.#requesting && surfaceId === 'contact-card' ? html` -
- progress_activity -
- ` : nothing} + - ) => { - const [target] = evt.composedPath(); - if (!(target instanceof HTMLElement)) { - return; - } - - const context: v0_8.Types.A2UIClientEventMessage["userAction"]["context"] = - {}; - if (evt.detail.action.context) { - const srcContext = evt.detail.action.context; - for (const item of srcContext) { - if (item.value.literalBoolean) { - context[item.key] = item.value.literalBoolean; - } else if (item.value.literalNumber) { - context[item.key] = item.value.literalNumber; - } else if (item.value.literalString) { - context[item.key] = item.value.literalString; - } else if (item.value.path) { - const path = this.#processor.resolvePath( - item.value.path, - evt.detail.dataContextPath - ); - const value = this.#processor.getData( - evt.detail.sourceComponent, - path, - surfaceId - ); - context[item.key] = value; + evt: v0_8.Events.StateEvent<"a2ui.action"> + ) => { + const [target] = evt.composedPath(); + if (!(target instanceof HTMLElement)) { + return; } - } - } - const message: v0_8.Types.A2UIClientEventMessage = { - userAction: { - surfaceId: surfaceId, - name: "ACTION: " + evt.detail.action.name, - sourceComponentId: target.id, - timestamp: new Date().toISOString(), - context, - }, - }; + const context: v0_8.Types.A2UIClientEventMessage["userAction"]["context"] = + {}; + if (evt.detail.action.context) { + const srcContext = evt.detail.action.context; + for (const item of srcContext) { + if (item.value.literalBoolean) { + context[item.key] = item.value.literalBoolean; + } else if (item.value.literalNumber) { + context[item.key] = item.value.literalNumber; + } else if (item.value.literalString) { + context[item.key] = item.value.literalString; + } else if (item.value.path) { + const path = this.#processor.resolvePath( + item.value.path, + evt.detail.dataContextPath + ); + const value = this.#processor.getData( + evt.detail.sourceComponent, + path, + surfaceId + ); + context[item.key] = value; + } + } + } + const message: v0_8.Types.A2UIClientEventMessage = { + userAction: { + surfaceId: surfaceId, + name: "ACTION: " + evt.detail.action.name, + sourceComponentId: target.id, + timestamp: new Date().toISOString(), + context, + }, + }; - await this.#sendAndProcessMessage(message); - }} + + await this.#sendAndProcessMessage(message); + }} .surfaceId=${surfaceId} .processor=${this.#processor} .enableCustomElements=${true} >
`; - } - )} + } + )}
`; } @@ -381,8 +388,6 @@ export class A2UIContactFinder extends SignalWatcher(LitElement) { this.#lastMessages = messages; - // this.#processor.clearSurfaces(); // Removed to allow partial updates - this.#processor.processMessages(messages); this.renderVersion++; // Force re-render of surfaces this.requestUpdate(); @@ -397,7 +402,11 @@ export class A2UIContactFinder extends SignalWatcher(LitElement) { ): Promise { try { this.#requesting = true; - const response = await this.#a2uiClient.send(message); + const response = await this.#a2uiClient.send(message, (chunkMessages) => { + this.#processor.processMessages(chunkMessages); + this.renderVersion++; + this.requestUpdate(); + }); this.#requesting = false; diff --git a/samples/client/lit/custom-components-example/middleware/a2a.ts b/samples/client/lit/custom-components-example/middleware/a2a.ts index dbf9f93d7..e0f38eddd 100644 --- a/samples/client/lit/custom-components-example/middleware/a2a.ts +++ b/samples/client/lit/custom-components-example/middleware/a2a.ts @@ -26,6 +26,7 @@ import { import { v4 as uuidv4 } from "uuid"; const A2UI_MIME_TYPE = "application/json+a2ui"; +const enableStreaming = process.env["ENABLE_STREAMING"] === "true"; const fetchWithCustomHeader: typeof fetch = async (url, init) => { const headers = new Headers(init?.headers); @@ -119,27 +120,49 @@ export const plugin = (): Plugin => { } const client = await createOrGetClient(); - const response = await client.sendMessage(sendParams); - if ("error" in response) { - console.error("Error:", response.error.message); - res.statusCode = 500; - res.setHeader("Content-Type", "application/json"); - res.end(JSON.stringify({ error: response.error.message })); - return; - } else { - const result = (response as SendMessageSuccessResponse) - .result as Task; - if (result.kind === "task") { + + try { + if (enableStreaming) { + const stream = await client.sendMessageStream(sendParams); res.statusCode = 200; + res.setHeader("Content-Type", "text/event-stream"); + res.setHeader("Cache-Control", "no-cache"); + res.setHeader("Connection", "keep-alive"); + + for await (const chunk of stream) { + // A2AClient unpacks the JSON-RPC, so chunk is an A2AStreamEventData + if (chunk.kind === "status-update" && chunk.status.message?.parts) { + res.write(`data: ${JSON.stringify(chunk.status.message.parts)}\n\n`); + } else if (chunk.kind === "message" && chunk.parts) { + res.write(`data: ${JSON.stringify(chunk.parts)}\n\n`); + } + } + res.end(); + } else { + const response = await client.sendMessage(sendParams); + res.setHeader("Cache-Control", "no-store"); + if ("error" in response) { + res.statusCode = 500; + res.setHeader("Content-Type", "application/json"); + res.end(JSON.stringify({ error: response.error.message })); + } else { + const result = (response as SendMessageSuccessResponse).result as Task; + res.statusCode = 200; + res.setHeader("Content-Type", "application/json"); + res.end(JSON.stringify(result.kind === "task" ? result.status.message?.parts || [] : [])); + } + } + } catch (e: any) { + console.error("Error during streaming:", e); + if (!res.headersSent) { + res.statusCode = 500; res.setHeader("Content-Type", "application/json"); - res.end(JSON.stringify(result.status.message?.parts)); - return; + res.end(JSON.stringify({ error: e.message || String(e) })); + } else { + res.write(`data: ${JSON.stringify({ error: e.message || String(e) })}\n\n`); + res.end(); } } - - res.statusCode = 200; - res.setHeader("Content-Type", "application/json"); - res.end(JSON.stringify([])); }); return;