Skip to content

Commit e71f941

Browse files
authored
Merge branch 'main' into feature/chroma-memory-service
2 parents 10d0852 + 663cb75 commit e71f941

File tree

4 files changed

+259
-32
lines changed

4 files changed

+259
-32
lines changed

src/google/adk/flows/llm_flows/contents.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,16 @@ async def run_async(
4242
self, invocation_context: InvocationContext, llm_request: LlmRequest
4343
) -> AsyncGenerator[Event, None]:
4444
from ...agents.llm_agent import LlmAgent
45+
from ...models.google_llm import Gemini
4546

4647
agent = invocation_context.agent
48+
preserve_function_call_ids = False
49+
if isinstance(agent, LlmAgent):
50+
canonical_model = agent.canonical_model
51+
preserve_function_call_ids = (
52+
isinstance(canonical_model, Gemini)
53+
and canonical_model.use_interactions_api
54+
)
4755

4856
# Preserve all contents that were added by instruction processor
4957
# (since llm_request.contents will be completely reassigned below)
@@ -55,13 +63,15 @@ async def run_async(
5563
invocation_context.branch,
5664
invocation_context.session.events,
5765
agent.name,
66+
preserve_function_call_ids=preserve_function_call_ids,
5867
)
5968
else:
6069
# Include current turn context only (no conversation history)
6170
llm_request.contents = _get_current_turn_contents(
6271
invocation_context.branch,
6372
invocation_context.session.events,
6473
agent.name,
74+
preserve_function_call_ids=preserve_function_call_ids,
6575
)
6676

6777
# Add instruction-related contents to proper position in conversation
@@ -360,7 +370,11 @@ def _process_compaction_events(events: list[Event]) -> list[Event]:
360370

361371

362372
def _get_contents(
363-
current_branch: Optional[str], events: list[Event], agent_name: str = ''
373+
current_branch: Optional[str],
374+
events: list[Event],
375+
agent_name: str = '',
376+
*,
377+
preserve_function_call_ids: bool = False,
364378
) -> list[types.Content]:
365379
"""Get the contents for the LLM request.
366380
@@ -370,6 +384,7 @@ def _get_contents(
370384
current_branch: The current branch of the agent.
371385
events: Events to process.
372386
agent_name: The name of the agent.
387+
preserve_function_call_ids: Whether to preserve function call ids.
373388
374389
Returns:
375390
A list of processed contents.
@@ -469,13 +484,18 @@ def _get_contents(
469484
for event in result_events:
470485
content = copy.deepcopy(event.content)
471486
if content:
472-
remove_client_function_call_id(content)
487+
if not preserve_function_call_ids:
488+
remove_client_function_call_id(content)
473489
contents.append(content)
474490
return contents
475491

476492

477493
def _get_current_turn_contents(
478-
current_branch: Optional[str], events: list[Event], agent_name: str = ''
494+
current_branch: Optional[str],
495+
events: list[Event],
496+
agent_name: str = '',
497+
*,
498+
preserve_function_call_ids: bool = False,
479499
) -> list[types.Content]:
480500
"""Get contents for the current turn only (no conversation history).
481501
@@ -491,6 +511,7 @@ def _get_current_turn_contents(
491511
current_branch: The current branch of the agent.
492512
events: A list of all session events.
493513
agent_name: The name of the agent.
514+
preserve_function_call_ids: Whether to preserve function call ids.
494515
495516
Returns:
496517
A list of contents for the current turn only, preserving context needed
@@ -502,7 +523,12 @@ def _get_current_turn_contents(
502523
if _should_include_event_in_context(current_branch, event) and (
503524
event.author == 'user' or _is_other_agent_reply(agent_name, event)
504525
):
505-
return _get_contents(current_branch, events[i:], agent_name)
526+
return _get_contents(
527+
current_branch,
528+
events[i:],
529+
agent_name,
530+
preserve_function_call_ids=preserve_function_call_ids,
531+
)
506532

507533
return []
508534

src/google/adk/memory/vertex_ai_memory_bank_service.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from .memory_entry import MemoryEntry
2828

2929
if TYPE_CHECKING:
30+
import vertexai
31+
3032
from ..sessions.session import Session
3133

3234
logger = logging.getLogger('google_adk.' + __name__)
@@ -88,8 +90,8 @@ async def add_session_to_memory(self, session: Session):
8890
'content': event.content.model_dump(exclude_none=True, mode='json')
8991
})
9092
if events:
91-
client = self._get_api_client()
92-
operation = client.agent_engines.memories.generate(
93+
api_client = self._get_api_client()
94+
operation = await api_client.agent_engines.memories.generate(
9395
name='reasoningEngines/' + self._agent_engine_id,
9496
direct_contents_source={'events': events},
9597
scope={
@@ -108,22 +110,24 @@ async def search_memory(self, *, app_name: str, user_id: str, query: str):
108110
if not self._agent_engine_id:
109111
raise ValueError('Agent Engine ID is required for Memory Bank.')
110112

111-
client = self._get_api_client()
112-
retrieved_memories_iterator = client.agent_engines.memories.retrieve(
113-
name='reasoningEngines/' + self._agent_engine_id,
114-
scope={
115-
'app_name': app_name,
116-
'user_id': user_id,
117-
},
118-
similarity_search_params={
119-
'search_query': query,
120-
},
113+
api_client = self._get_api_client()
114+
retrieved_memories_iterator = (
115+
await api_client.agent_engines.memories.retrieve(
116+
name='reasoningEngines/' + self._agent_engine_id,
117+
scope={
118+
'app_name': app_name,
119+
'user_id': user_id,
120+
},
121+
similarity_search_params={
122+
'search_query': query,
123+
},
124+
)
121125
)
122126

123127
logger.info('Search memory response received.')
124128

125-
memory_events = []
126-
for retrieved_memory in retrieved_memories_iterator:
129+
memory_events: list[MemoryEntry] = []
130+
async for retrieved_memory in retrieved_memories_iterator:
127131
# TODO: add more complex error handling
128132
logger.debug('Retrieved memory: %s', retrieved_memory)
129133
memory_events.append(
@@ -138,21 +142,22 @@ async def search_memory(self, *, app_name: str, user_id: str, query: str):
138142
)
139143
return SearchMemoryResponse(memories=memory_events)
140144

141-
def _get_api_client(self):
145+
def _get_api_client(self) -> vertexai.AsyncClient:
142146
"""Instantiates an API client for the given project and location.
143147
144148
It needs to be instantiated inside each request so that the event loop
145149
management can be properly propagated.
146150
Returns:
147-
An API client for the given project and location or express mode api key.
151+
An async API client for the given project and location or express mode api
152+
key.
148153
"""
149154
import vertexai
150155

151156
return vertexai.Client(
152157
project=self._project,
153158
location=self._location,
154159
api_key=self._express_mode_api_key,
155-
)
160+
).aio
156161

157162

158163
def _should_filter_out_event(content: types.Content) -> bool:

tests/unittests/flows/llm_flows/test_contents.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from google.adk.flows.llm_flows.contents import request_processor
2020
from google.adk.flows.llm_flows.functions import REQUEST_CONFIRMATION_FUNCTION_CALL_NAME
2121
from google.adk.flows.llm_flows.functions import REQUEST_EUC_FUNCTION_CALL_NAME
22+
from google.adk.models.google_llm import Gemini
2223
from google.adk.models.llm_request import LlmRequest
2324
from google.genai import types
2425
import pytest
@@ -931,3 +932,139 @@ async def test_function_response_with_thought_not_filtered():
931932
fr_parts = [p for p in fr_content.parts if p.function_response]
932933
assert len(fr_parts) == 1
933934
assert fr_parts[0].function_response.name == "calc_tool"
935+
936+
937+
@pytest.mark.asyncio
938+
async def test_adk_function_call_ids_are_stripped_for_non_interactions_model():
939+
"""Test ADK generated ids are removed for non-interactions requests."""
940+
agent = Agent(model="gemini-2.5-flash", name="test_agent")
941+
llm_request = LlmRequest(model="gemini-2.5-flash")
942+
invocation_context = await testing_utils.create_invocation_context(
943+
agent=agent
944+
)
945+
946+
function_call_id = "adk-test-call-id"
947+
events = [
948+
Event(
949+
invocation_id="inv1",
950+
author="user",
951+
content=types.UserContent("Call the tool"),
952+
),
953+
Event(
954+
invocation_id="inv2",
955+
author="test_agent",
956+
content=types.Content(
957+
role="model",
958+
parts=[
959+
types.Part(
960+
function_call=types.FunctionCall(
961+
id=function_call_id,
962+
name="test_tool",
963+
args={"x": 1},
964+
)
965+
)
966+
],
967+
),
968+
),
969+
Event(
970+
invocation_id="inv3",
971+
author="test_agent",
972+
content=types.Content(
973+
role="user",
974+
parts=[
975+
types.Part(
976+
function_response=types.FunctionResponse(
977+
id=function_call_id,
978+
name="test_tool",
979+
response={"result": 2},
980+
)
981+
)
982+
],
983+
),
984+
),
985+
]
986+
invocation_context.session.events = events
987+
988+
async for _ in contents.request_processor.run_async(
989+
invocation_context, llm_request
990+
):
991+
pass
992+
993+
model_fc_part = llm_request.contents[1].parts[0]
994+
assert model_fc_part.function_call is not None
995+
assert model_fc_part.function_call.id is None
996+
997+
user_fr_part = llm_request.contents[2].parts[0]
998+
assert user_fr_part.function_response is not None
999+
assert user_fr_part.function_response.id is None
1000+
1001+
1002+
@pytest.mark.asyncio
1003+
async def test_adk_function_call_ids_preserved_for_interactions_model():
1004+
"""Test ADK generated ids are preserved for interactions requests."""
1005+
agent = Agent(
1006+
model=Gemini(
1007+
model="gemini-2.5-flash",
1008+
use_interactions_api=True,
1009+
),
1010+
name="test_agent",
1011+
)
1012+
llm_request = LlmRequest(model="gemini-2.5-flash")
1013+
invocation_context = await testing_utils.create_invocation_context(
1014+
agent=agent
1015+
)
1016+
1017+
function_call_id = "adk-test-call-id"
1018+
events = [
1019+
Event(
1020+
invocation_id="inv1",
1021+
author="user",
1022+
content=types.UserContent("Call the tool"),
1023+
),
1024+
Event(
1025+
invocation_id="inv2",
1026+
author="test_agent",
1027+
content=types.Content(
1028+
role="model",
1029+
parts=[
1030+
types.Part(
1031+
function_call=types.FunctionCall(
1032+
id=function_call_id,
1033+
name="test_tool",
1034+
args={"x": 1},
1035+
)
1036+
)
1037+
],
1038+
),
1039+
),
1040+
Event(
1041+
invocation_id="inv3",
1042+
author="test_agent",
1043+
content=types.Content(
1044+
role="user",
1045+
parts=[
1046+
types.Part(
1047+
function_response=types.FunctionResponse(
1048+
id=function_call_id,
1049+
name="test_tool",
1050+
response={"result": 2},
1051+
)
1052+
)
1053+
],
1054+
),
1055+
),
1056+
]
1057+
invocation_context.session.events = events
1058+
1059+
async for _ in contents.request_processor.run_async(
1060+
invocation_context, llm_request
1061+
):
1062+
pass
1063+
1064+
model_fc_part = llm_request.contents[1].parts[0]
1065+
assert model_fc_part.function_call is not None
1066+
assert model_fc_part.function_call.id == function_call_id
1067+
1068+
user_fr_part = llm_request.contents[2].parts[0]
1069+
assert user_fr_part.function_response is not None
1070+
assert user_fr_part.function_response.id == function_call_id

Comments (0) — this commit has no comments.