Skip to content

Commit fbac3ea

Browse files
committed
Streaming support for restaurant
1 parent ea11012 commit fbac3ea

File tree

2 files changed

+81
-65
lines changed

2 files changed

+81
-65
lines changed

samples/agent/adk/restaurant_finder/agent.py

Lines changed: 73 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
Part,
2828
TextPart,
2929
)
30+
from google.adk.agents import run_config
3031
from google.adk.agents.llm_agent import LlmAgent
3132
from google.adk.artifacts import InMemoryArtifactService
3233
from google.adk.memory.in_memory_memory_service import InMemoryMemoryService
@@ -45,7 +46,11 @@
4546
from a2ui.core.parser.parser import parse_response, ResponsePart
4647
from a2ui.basic_catalog.provider import BasicCatalog
4748
from a2ui.core.schema.common_modifiers import remove_strict_validation
48-
from a2ui.a2a import create_a2ui_part, get_a2ui_agent_extension, parse_response_to_parts
49+
from a2ui.a2a import (
50+
get_a2ui_agent_extension,
51+
parse_response_to_parts,
52+
stream_response_to_parts,
53+
)
4954

5055
logger = logging.getLogger(__name__)
5156

@@ -71,6 +76,7 @@ def __init__(self, base_url: str, use_ui: bool = False):
7176
)
7277
self._agent = self._build_agent(use_ui)
7378
self._user_id = "remote_agent"
79+
self._parsers = {}
7480
self._runner = Runner(
7581
app_name=self._agent.name,
7682
agent=self._agent,
@@ -80,14 +86,17 @@ def __init__(self, base_url: str, use_ui: bool = False):
8086
)
8187

8288
def get_agent_card(self) -> AgentCard:
89+
extensions = []
90+
if self.use_ui:
91+
extensions.append(
92+
get_a2ui_agent_extension(
93+
self._schema_manager.accepts_inline_catalogs,
94+
self._schema_manager.supported_catalog_ids,
95+
)
96+
)
8397
capabilities = AgentCapabilities(
8498
streaming=True,
85-
extensions=[
86-
get_a2ui_agent_extension(
87-
self._schema_manager.accepts_inline_catalogs,
88-
self._schema_manager.supported_catalog_ids,
89-
)
90-
],
99+
extensions=extensions,
91100
)
92101
skill = AgentSkill(
93102
id="find_restaurants",
@@ -161,26 +170,28 @@ async def stream(self, query, session_id) -> AsyncIterable[dict[str, Any]]:
161170
current_query_text = query
162171

163172
# Ensure schema was loaded
164-
selected_catalog = self._schema_manager.get_selected_catalog()
165-
if self.use_ui and not selected_catalog.catalog_schema:
166-
logger.error(
167-
"--- RestaurantAgent.stream: A2UI_SCHEMA is not loaded. "
168-
"Cannot perform UI validation. ---"
169-
)
170-
yield {
171-
"is_task_complete": True,
172-
"parts": [
173-
Part(
174-
root=TextPart(
175-
text=(
176-
"I'm sorry, I'm facing an internal configuration error with"
177-
" my UI components. Please contact support."
178-
)
179-
)
180-
)
181-
],
182-
}
183-
return
173+
selected_catalog = None
174+
if self.use_ui:
175+
selected_catalog = self._schema_manager.get_selected_catalog()
176+
if not selected_catalog.catalog_schema:
177+
logger.error(
178+
"--- RestaurantAgent.stream: A2UI_SCHEMA is not loaded. "
179+
"Cannot perform UI validation. ---"
180+
)
181+
yield {
182+
"is_task_complete": True,
183+
"parts": [
184+
Part(
185+
root=TextPart(
186+
text=(
187+
"I'm sorry, I'm facing an internal configuration error with"
188+
" my UI components. Please contact support."
189+
)
190+
)
191+
)
192+
],
193+
}
194+
return
184195

185196
while attempt <= max_retries:
186197
attempt += 1
@@ -192,45 +203,46 @@ async def stream(self, query, session_id) -> AsyncIterable[dict[str, Any]]:
192203
current_message = types.Content(
193204
role="user", parts=[types.Part.from_text(text=current_query_text)]
194205
)
195-
final_response_content = None
196206

197-
async for event in self._runner.run_async(
198-
user_id=self._user_id,
199-
session_id=session.id,
200-
new_message=current_message,
201-
):
202-
logger.info(f"Event from runner: {event}")
203-
if event.is_final_response():
204-
if event.content and event.content.parts and event.content.parts[0].text:
205-
final_response_content = "\n".join(
206-
[p.text for p in event.content.parts if p.text]
207-
)
208-
break # Got the final response, stop consuming events
209-
else:
210-
logger.info(f"Intermediate event: {event}")
211-
# Yield intermediate updates on every attempt
207+
full_content_list = []
208+
209+
async def token_stream():
210+
async for event in self._runner.run_async(
211+
user_id=self._user_id,
212+
session_id=session.id,
213+
run_config=run_config.RunConfig(
214+
streaming_mode=run_config.StreamingMode.SSE
215+
),
216+
new_message=current_message,
217+
):
218+
if event.content and event.content.parts:
219+
for p in event.content.parts:
220+
if p.text:
221+
full_content_list.append(p.text)
222+
yield p.text
223+
224+
if self.use_ui:
225+
from a2ui.core.parser.streaming import A2uiStreamParser
226+
227+
if session_id not in self._parsers:
228+
self._parsers[session_id] = A2uiStreamParser(catalog=selected_catalog)
229+
230+
async for part in stream_response_to_parts(
231+
self._parsers[session_id],
232+
token_stream(),
233+
):
212234
yield {
213235
"is_task_complete": False,
214-
"updates": self.get_processing_message(),
236+
"parts": [part],
237+
}
238+
else:
239+
async for token in token_stream():
240+
yield {
241+
"is_task_complete": False,
242+
"updates": token,
215243
}
216244

217-
if final_response_content is None:
218-
logger.warning(
219-
"--- RestaurantAgent.stream: Received no final response content from"
220-
f" runner (Attempt {attempt}). ---"
221-
)
222-
if attempt <= max_retries:
223-
current_query_text = (
224-
"I received no response. Please try again."
225-
f"Please retry the original request: '{query}'"
226-
)
227-
continue # Go to next retry
228-
else:
229-
# Retries exhausted on no-response
230-
final_response_content = (
231-
"I'm sorry, I encountered an error and couldn't process your request."
232-
)
233-
# Fall through to send this as a text-only error
245+
final_response_content = "".join(full_content_list)
234246

235247
is_valid = False
236248
error_message = ""

samples/agent/adk/restaurant_finder/agent_executor.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,10 +129,14 @@ async def execute(
129129
async for item in agent.stream(query, task.context_id):
130130
is_task_complete = item["is_task_complete"]
131131
if not is_task_complete:
132-
await updater.update_status(
133-
TaskState.working,
134-
new_agent_text_message(item["updates"], task.context_id, task.id),
135-
)
132+
message = None
133+
if "parts" in item:
134+
message = new_agent_parts_message(item["parts"], task.context_id, task.id)
135+
elif "updates" in item:
136+
message = new_agent_text_message(item["updates"], task.context_id, task.id)
137+
138+
if message:
139+
await updater.update_status(TaskState.working, message)
136140
continue
137141

138142
final_state = (

0 commit comments

Comments
 (0)