Skip to content

Commit 9aa575d

Browse files
authored
Merge branch 'main' into fix/handle-circular-refs
2 parents 3a1019d + 4a88804 commit 9aa575d

File tree

2 files changed

+231
-10
lines changed

2 files changed

+231
-10
lines changed

src/google/adk/memory/vertex_ai_memory_bank_service.py

Lines changed: 89 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,11 @@
6565
'wait_for_completion',
6666
})
6767

68+
_ENABLE_CONSOLIDATION_KEY = 'enable_consolidation'
69+
# Vertex docs for GenerateMemoriesRequest.DirectMemoriesSource allow
70+
# at most 5 direct_memories per request.
71+
_MAX_DIRECT_MEMORIES_PER_GENERATE_CALL = 5
72+
6873

6974
def _supports_generate_memories_metadata() -> bool:
7075
"""Returns whether installed Vertex SDK supports config.metadata."""
@@ -160,6 +165,11 @@ def __init__(
160165
not use Google AI Studio API key for this field. For more details, visit
161166
https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview
162167
"""
168+
if not agent_engine_id:
169+
raise ValueError(
170+
'agent_engine_id is required for VertexAiMemoryBankService.'
171+
)
172+
163173
self._project = project
164174
self._location = location
165175
self._agent_engine_id = agent_engine_id
@@ -219,7 +229,22 @@ async def add_memory(
219229
memories: Sequence[MemoryEntry],
220230
custom_metadata: Mapping[str, object] | None = None,
221231
) -> None:
222-
"""Adds explicit memory items via Vertex memories.create."""
232+
"""Adds explicit memory items using Vertex Memory Bank.
233+
234+
By default, this writes directly via `memories.create`.
235+
If `custom_metadata["enable_consolidation"]` is set to True, this uses
236+
`memories.generate` with `direct_memories_source` so provided memories are
237+
consolidated server-side.
238+
"""
239+
if _is_consolidation_enabled(custom_metadata):
240+
await self._add_memories_via_generate_direct_memories_source(
241+
app_name=app_name,
242+
user_id=user_id,
243+
memories=memories,
244+
custom_metadata=custom_metadata,
245+
)
246+
return
247+
223248
await self._add_memories_via_create(
224249
app_name=app_name,
225250
user_id=user_id,
@@ -235,9 +260,6 @@ async def _add_events_to_memory_from_events(
235260
events_to_process: Sequence[Event],
236261
custom_metadata: Mapping[str, object] | None = None,
237262
) -> None:
238-
if not self._agent_engine_id:
239-
raise ValueError('Agent Engine ID is required for Memory Bank.')
240-
241263
direct_events = []
242264
for event in events_to_process:
243265
if _should_filter_out_event(event.content):
@@ -272,9 +294,6 @@ async def _add_memories_via_create(
272294
custom_metadata: Mapping[str, object] | None = None,
273295
) -> None:
274296
"""Adds direct memory items without server-side extraction."""
275-
if not self._agent_engine_id:
276-
raise ValueError('Agent Engine ID is required for Memory Bank.')
277-
278297
normalized_memories = _normalize_memories_for_create(memories)
279298
api_client = self._get_api_client()
280299
for index, memory in enumerate(normalized_memories):
@@ -300,11 +319,41 @@ async def _add_memories_via_create(
300319
logger.info('Create memory response received.')
301320
logger.debug('Create memory response: %s', operation)
302321

322+
async def _add_memories_via_generate_direct_memories_source(
323+
self,
324+
*,
325+
app_name: str,
326+
user_id: str,
327+
memories: Sequence[MemoryEntry],
328+
custom_metadata: Mapping[str, object] | None = None,
329+
) -> None:
330+
"""Adds memories via generate API with direct_memories_source."""
331+
normalized_memories = _normalize_memories_for_create(memories)
332+
memory_texts = [
333+
_memory_entry_to_fact(m, index=i)
334+
for i, m in enumerate(normalized_memories)
335+
]
336+
api_client = self._get_api_client()
337+
config = _build_generate_memories_config(custom_metadata)
338+
for memory_batch in _iter_memory_batches(memory_texts):
339+
operation = await api_client.agent_engines.memories.generate(
340+
name='reasoningEngines/' + self._agent_engine_id,
341+
direct_memories_source={
342+
'direct_memories': [
343+
{'fact': memory_text} for memory_text in memory_batch
344+
]
345+
},
346+
scope={
347+
'app_name': app_name,
348+
'user_id': user_id,
349+
},
350+
config=config,
351+
)
352+
logger.info('Generate direct memory response received.')
353+
logger.debug('Generate direct memory response: %s', operation)
354+
303355
@override
304356
async def search_memory(self, *, app_name: str, user_id: str, query: str):
305-
if not self._agent_engine_id:
306-
raise ValueError('Agent Engine ID is required for Memory Bank.')
307-
308357
api_client = self._get_api_client()
309358
retrieved_memories_iterator = (
310359
await api_client.agent_engines.memories.retrieve(
@@ -379,6 +428,8 @@ def _build_generate_memories_config(
379428

380429
metadata_by_key: dict[str, object] = {}
381430
for key, value in custom_metadata.items():
431+
if key == _ENABLE_CONSOLIDATION_KEY:
432+
continue
382433
if key == 'ttl':
383434
if value is None:
384435
continue
@@ -456,6 +507,8 @@ def _build_create_memory_config(
456507
metadata_by_key: dict[str, object] = {}
457508
custom_revision_labels: dict[str, str] = {}
458509
for key, value in (custom_metadata or {}).items():
510+
if key == _ENABLE_CONSOLIDATION_KEY:
511+
continue
459512
if key == 'metadata':
460513
if value is None:
461514
continue
@@ -641,6 +694,32 @@ def _extract_revision_labels(
641694
return revision_labels
642695

643696

697+
def _is_consolidation_enabled(
698+
custom_metadata: Mapping[str, object] | None,
699+
) -> bool:
700+
"""Returns whether direct memories should be consolidated via generate API."""
701+
if not custom_metadata:
702+
return False
703+
enable_consolidation = custom_metadata.get(_ENABLE_CONSOLIDATION_KEY)
704+
if enable_consolidation is None:
705+
return False
706+
if not isinstance(enable_consolidation, bool):
707+
raise TypeError(
708+
f'custom_metadata["{_ENABLE_CONSOLIDATION_KEY}"] must be a bool.'
709+
)
710+
return enable_consolidation
711+
712+
713+
def _iter_memory_batches(memories: Sequence[str]) -> Sequence[Sequence[str]]:
714+
"""Returns memory slices that comply with direct_memories limits."""
715+
memory_batches: list[Sequence[str]] = []
716+
for index in range(0, len(memories), _MAX_DIRECT_MEMORIES_PER_GENERATE_CALL):
717+
memory_batches.append(
718+
memories[index : index + _MAX_DIRECT_MEMORIES_PER_GENERATE_CALL]
719+
)
720+
return memory_batches
721+
722+
644723
def _build_vertex_metadata(
645724
metadata_by_key: Mapping[str, object],
646725
) -> dict[str, object]:

tests/unittests/memory/test_vertex_ai_memory_bank_service.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,14 @@ async def test_initialize_with_project_location_and_api_key_error():
230230
)
231231

232232

233+
def test_initialize_without_agent_engine_id_error():
234+
with pytest.raises(
235+
ValueError,
236+
match='agent_engine_id is required for VertexAiMemoryBankService',
237+
):
238+
mock_vertex_ai_memory_bank_service(agent_engine_id=None)
239+
240+
233241
@pytest.mark.asyncio
234242
async def test_add_session_to_memory(mock_vertexai_client):
235243
memory_service = mock_vertex_ai_memory_bank_service()
@@ -481,6 +489,7 @@ async def test_add_memory_calls_create(
481489
),
482490
],
483491
custom_metadata={
492+
'enable_consolidation': False,
484493
'ttl': '6000s',
485494
'source': 'agent',
486495
},
@@ -518,6 +527,139 @@ async def test_add_memory_calls_create(
518527
vertex_common_types.AgentEngineMemoryConfig(**create_config)
519528

520529

530+
@pytest.mark.asyncio
531+
async def test_add_memory_enable_consolidation_calls_generate_direct_source(
532+
mock_vertexai_client,
533+
):
534+
memory_service = mock_vertex_ai_memory_bank_service()
535+
await memory_service.add_memory(
536+
app_name=MOCK_SESSION.app_name,
537+
user_id=MOCK_SESSION.user_id,
538+
memories=[
539+
MemoryEntry(
540+
content=types.Content(parts=[types.Part(text='fact one')])
541+
),
542+
MemoryEntry(
543+
content=types.Content(parts=[types.Part(text='fact two')])
544+
),
545+
],
546+
custom_metadata={
547+
'enable_consolidation': True,
548+
'source': 'agent',
549+
},
550+
)
551+
552+
expected_config = {'wait_for_completion': False}
553+
if _supports_generate_memories_metadata():
554+
expected_config['metadata'] = {'source': {'string_value': 'agent'}}
555+
556+
mock_vertexai_client.agent_engines.memories.generate.assert_called_once_with(
557+
name='reasoningEngines/123',
558+
direct_memories_source={
559+
'direct_memories': [
560+
{'fact': 'fact one'},
561+
{'fact': 'fact two'},
562+
]
563+
},
564+
scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID},
565+
config=expected_config,
566+
)
567+
mock_vertexai_client.agent_engines.memories.create.assert_not_called()
568+
569+
generate_config = (
570+
mock_vertexai_client.agent_engines.memories.generate.call_args.kwargs[
571+
'config'
572+
]
573+
)
574+
vertex_common_types.GenerateAgentEngineMemoriesConfig(**generate_config)
575+
576+
577+
@pytest.mark.asyncio
578+
async def test_add_memory_enable_consolidation_batches_generate_calls(
579+
mock_vertexai_client,
580+
):
581+
memory_service = mock_vertex_ai_memory_bank_service()
582+
await memory_service.add_memory(
583+
app_name=MOCK_SESSION.app_name,
584+
user_id=MOCK_SESSION.user_id,
585+
memories=[
586+
MemoryEntry(
587+
content=types.Content(parts=[types.Part(text='fact one')])
588+
),
589+
MemoryEntry(
590+
content=types.Content(parts=[types.Part(text='fact two')])
591+
),
592+
MemoryEntry(
593+
content=types.Content(parts=[types.Part(text='fact three')])
594+
),
595+
MemoryEntry(
596+
content=types.Content(parts=[types.Part(text='fact four')])
597+
),
598+
MemoryEntry(
599+
content=types.Content(parts=[types.Part(text='fact five')])
600+
),
601+
MemoryEntry(
602+
content=types.Content(parts=[types.Part(text='fact six')])
603+
),
604+
],
605+
custom_metadata={
606+
'enable_consolidation': True,
607+
},
608+
)
609+
610+
mock_vertexai_client.agent_engines.memories.generate.assert_has_awaits([
611+
mock.call(
612+
name='reasoningEngines/123',
613+
direct_memories_source={
614+
'direct_memories': [
615+
{'fact': 'fact one'},
616+
{'fact': 'fact two'},
617+
{'fact': 'fact three'},
618+
{'fact': 'fact four'},
619+
{'fact': 'fact five'},
620+
]
621+
},
622+
scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID},
623+
config={'wait_for_completion': False},
624+
),
625+
mock.call(
626+
name='reasoningEngines/123',
627+
direct_memories_source={
628+
'direct_memories': [
629+
{'fact': 'fact six'},
630+
]
631+
},
632+
scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID},
633+
config={'wait_for_completion': False},
634+
),
635+
])
636+
assert mock_vertexai_client.agent_engines.memories.generate.await_count == 2
637+
mock_vertexai_client.agent_engines.memories.create.assert_not_called()
638+
639+
640+
@pytest.mark.asyncio
641+
async def test_add_memory_invalid_enable_consolidation_type_raises(
642+
mock_vertexai_client,
643+
):
644+
memory_service = mock_vertex_ai_memory_bank_service()
645+
with pytest.raises(
646+
TypeError,
647+
match=r'custom_metadata\["enable_consolidation"\] must be a bool',
648+
):
649+
await memory_service.add_memory(
650+
app_name=MOCK_SESSION.app_name,
651+
user_id=MOCK_SESSION.user_id,
652+
memories=[
653+
MemoryEntry(
654+
content=types.Content(parts=[types.Part(text='fact one')])
655+
)
656+
],
657+
custom_metadata={'enable_consolidation': 'yes'},
658+
)
659+
mock_vertexai_client.agent_engines.memories.generate.assert_not_called()
660+
mock_vertexai_client.agent_engines.memories.create.assert_not_called()
661+
662+
521663
@pytest.mark.asyncio
522664
async def test_add_memory_calls_create_with_memory_entry_metadata(
523665
mock_vertexai_client,

0 commit comments

Comments
 (0)