6565 'wait_for_completion' ,
6666})
6767
68+ _ENABLE_CONSOLIDATION_KEY = 'enable_consolidation'
69+ # Vertex docs for GenerateMemoriesRequest.DirectMemoriesSource allow
70+ # at most 5 direct_memories per request.
71+ _MAX_DIRECT_MEMORIES_PER_GENERATE_CALL = 5
72+
6873
6974def _supports_generate_memories_metadata () -> bool :
7075 """Returns whether installed Vertex SDK supports config.metadata."""
@@ -160,6 +165,11 @@ def __init__(
160165 not use Google AI Studio API key for this field. For more details, visit
161166 https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview
162167 """
168+ if not agent_engine_id :
169+ raise ValueError (
170+ 'agent_engine_id is required for VertexAiMemoryBankService.'
171+ )
172+
163173 self ._project = project
164174 self ._location = location
165175 self ._agent_engine_id = agent_engine_id
@@ -219,7 +229,22 @@ async def add_memory(
219229 memories : Sequence [MemoryEntry ],
220230 custom_metadata : Mapping [str , object ] | None = None ,
221231 ) -> None :
222- """Adds explicit memory items via Vertex memories.create."""
232+ """Adds explicit memory items using Vertex Memory Bank.
233+
234+ By default, this writes directly via `memories.create`.
235+ If `custom_metadata["enable_consolidation"]` is set to True, this uses
236+ `memories.generate` with `direct_memories_source` so provided memories are
237+ consolidated server-side.
238+ """
239+ if _is_consolidation_enabled (custom_metadata ):
240+ await self ._add_memories_via_generate_direct_memories_source (
241+ app_name = app_name ,
242+ user_id = user_id ,
243+ memories = memories ,
244+ custom_metadata = custom_metadata ,
245+ )
246+ return
247+
223248 await self ._add_memories_via_create (
224249 app_name = app_name ,
225250 user_id = user_id ,
@@ -235,9 +260,6 @@ async def _add_events_to_memory_from_events(
235260 events_to_process : Sequence [Event ],
236261 custom_metadata : Mapping [str , object ] | None = None ,
237262 ) -> None :
238- if not self ._agent_engine_id :
239- raise ValueError ('Agent Engine ID is required for Memory Bank.' )
240-
241263 direct_events = []
242264 for event in events_to_process :
243265 if _should_filter_out_event (event .content ):
@@ -272,9 +294,6 @@ async def _add_memories_via_create(
272294 custom_metadata : Mapping [str , object ] | None = None ,
273295 ) -> None :
274296 """Adds direct memory items without server-side extraction."""
275- if not self ._agent_engine_id :
276- raise ValueError ('Agent Engine ID is required for Memory Bank.' )
277-
278297 normalized_memories = _normalize_memories_for_create (memories )
279298 api_client = self ._get_api_client ()
280299 for index , memory in enumerate (normalized_memories ):
@@ -300,11 +319,41 @@ async def _add_memories_via_create(
300319 logger .info ('Create memory response received.' )
301320 logger .debug ('Create memory response: %s' , operation )
302321
322+ async def _add_memories_via_generate_direct_memories_source (
323+ self ,
324+ * ,
325+ app_name : str ,
326+ user_id : str ,
327+ memories : Sequence [MemoryEntry ],
328+ custom_metadata : Mapping [str , object ] | None = None ,
329+ ) -> None :
330+ """Adds memories via generate API with direct_memories_source."""
331+ normalized_memories = _normalize_memories_for_create (memories )
332+ memory_texts = [
333+ _memory_entry_to_fact (m , index = i )
334+ for i , m in enumerate (normalized_memories )
335+ ]
336+ api_client = self ._get_api_client ()
337+ config = _build_generate_memories_config (custom_metadata )
338+ for memory_batch in _iter_memory_batches (memory_texts ):
339+ operation = await api_client .agent_engines .memories .generate (
340+ name = 'reasoningEngines/' + self ._agent_engine_id ,
341+ direct_memories_source = {
342+ 'direct_memories' : [
343+ {'fact' : memory_text } for memory_text in memory_batch
344+ ]
345+ },
346+ scope = {
347+ 'app_name' : app_name ,
348+ 'user_id' : user_id ,
349+ },
350+ config = config ,
351+ )
352+ logger .info ('Generate direct memory response received.' )
353+ logger .debug ('Generate direct memory response: %s' , operation )
354+
303355 @override
304356 async def search_memory (self , * , app_name : str , user_id : str , query : str ):
305- if not self ._agent_engine_id :
306- raise ValueError ('Agent Engine ID is required for Memory Bank.' )
307-
308357 api_client = self ._get_api_client ()
309358 retrieved_memories_iterator = (
310359 await api_client .agent_engines .memories .retrieve (
@@ -379,6 +428,8 @@ def _build_generate_memories_config(
379428
380429 metadata_by_key : dict [str , object ] = {}
381430 for key , value in custom_metadata .items ():
431+ if key == _ENABLE_CONSOLIDATION_KEY :
432+ continue
382433 if key == 'ttl' :
383434 if value is None :
384435 continue
@@ -456,6 +507,8 @@ def _build_create_memory_config(
456507 metadata_by_key : dict [str , object ] = {}
457508 custom_revision_labels : dict [str , str ] = {}
458509 for key , value in (custom_metadata or {}).items ():
510+ if key == _ENABLE_CONSOLIDATION_KEY :
511+ continue
459512 if key == 'metadata' :
460513 if value is None :
461514 continue
@@ -641,6 +694,32 @@ def _extract_revision_labels(
641694 return revision_labels
642695
643696
697+ def _is_consolidation_enabled (
698+ custom_metadata : Mapping [str , object ] | None ,
699+ ) -> bool :
700+ """Returns whether direct memories should be consolidated via generate API."""
701+ if not custom_metadata :
702+ return False
703+ enable_consolidation = custom_metadata .get (_ENABLE_CONSOLIDATION_KEY )
704+ if enable_consolidation is None :
705+ return False
706+ if not isinstance (enable_consolidation , bool ):
707+ raise TypeError (
708+ f'custom_metadata["{ _ENABLE_CONSOLIDATION_KEY } "] must be a bool.'
709+ )
710+ return enable_consolidation
711+
712+
713+ def _iter_memory_batches (memories : Sequence [str ]) -> Sequence [Sequence [str ]]:
714+ """Returns memory slices that comply with direct_memories limits."""
715+ memory_batches : list [Sequence [str ]] = []
716+ for index in range (0 , len (memories ), _MAX_DIRECT_MEMORIES_PER_GENERATE_CALL ):
717+ memory_batches .append (
718+ memories [index : index + _MAX_DIRECT_MEMORIES_PER_GENERATE_CALL ]
719+ )
720+ return memory_batches
721+
722+
644723def _build_vertex_metadata (
645724 metadata_by_key : Mapping [str , object ],
646725) -> dict [str , object ]:
0 commit comments