Skip to content

Commit 32f2ec3

Browse files
XinranTangcopybara-github
authored andcommitted
feat: Set agent_state in invocation context right before yielding the checkpoint event
PiperOrigin-RevId: 816804179
1 parent 7517924 commit 32f2ec3

File tree

9 files changed

+113
-39
lines changed

9 files changed

+113
-39
lines changed

src/google/adk/agents/base_agent.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -184,21 +184,19 @@ def _load_agent_state(
184184
def _create_agent_state_event(
185185
self,
186186
ctx: InvocationContext,
187-
*,
188-
agent_state: Optional[BaseAgentState] = None,
189-
end_of_agent: bool = False,
190187
) -> Event:
191-
"""Returns an event with agent state.
188+
"""Returns an event with current agent state set in the invocation context.
192189
193190
Args:
194191
ctx: The invocation context.
195-
agent_state: The agent state to checkpoint.
196-
end_of_agent: Whether the agent is finished running.
192+
193+
Returns:
194+
An event with the current agent state set in the invocation context.
197195
"""
198196
event_actions = EventActions()
199-
if agent_state:
200-
event_actions.agent_state = agent_state.model_dump(mode='json')
201-
if end_of_agent:
197+
if (agent_state := ctx.agent_states.get(self.name)) is not None:
198+
event_actions.agent_state = agent_state
199+
if ctx.end_of_agents.get(self.name):
202200
event_actions.end_of_agent = True
203201
return Event(
204202
invocation_id=ctx.invocation_id,

src/google/adk/agents/invocation_context.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -217,10 +217,37 @@ def is_resumable(self) -> bool:
217217
and self.resumability_config.is_resumable
218218
)
219219

220-
def reset_agent_state(self, agent_name: str) -> None:
221-
"""Resets the state of an agent, allowing it to be re-run."""
222-
self.agent_states.pop(agent_name, None)
223-
self.end_of_agents.pop(agent_name, None)
220+
def set_agent_state(
221+
self,
222+
agent_name: str,
223+
*,
224+
agent_state: Optional[BaseAgentState] = None,
225+
end_of_agent: bool = False,
226+
) -> None:
227+
"""Sets the state of an agent in this invocation.
228+
229+
* If end_of_agent is True, will set the end_of_agent flag to True and
230+
clear the agent_state.
231+
* Otherwise, if agent_state is not None, will set the agent_state and
232+
reset the end_of_agent flag to False.
233+
* Otherwise, will clear the agent_state and end_of_agent flag, to allow the
234+
agent to re-run.
235+
236+
Args:
237+
agent_name: The name of the agent.
238+
agent_state: The state of the agent. Will be ignored if end_of_agent is
239+
True.
240+
end_of_agent: Whether the agent has finished running.
241+
"""
242+
if end_of_agent:
243+
self.end_of_agents[agent_name] = True
244+
self.agent_states.pop(agent_name, None)
245+
elif agent_state is not None:
246+
self.agent_states[agent_name] = agent_state.model_dump(mode="json")
247+
self.end_of_agents[agent_name] = False
248+
else:
249+
self.end_of_agents.pop(agent_name, None)
250+
self.agent_states.pop(agent_name, None)
224251

225252
def populate_invocation_agent_states(self) -> None:
226253
"""Populates agent states for the current invocation if it is resumable.

src/google/adk/agents/llm_agent.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,8 @@ async def _run_async_impl(
388388
async for event in agen:
389389
yield event
390390

391-
yield self._create_agent_state_event(ctx, end_of_agent=True)
391+
ctx.set_agent_state(self.name, end_of_agent=True)
392+
yield self._create_agent_state_event(ctx)
392393
return
393394

394395
async with Aclosing(self._llm_flow.run_async(ctx)) as agen:
@@ -399,7 +400,8 @@ async def _run_async_impl(
399400
return
400401

401402
if ctx.is_resumable:
402-
yield self._create_agent_state_event(ctx, end_of_agent=True)
403+
ctx.set_agent_state(self.name, end_of_agent=True)
404+
yield self._create_agent_state_event(ctx)
403405

404406
@override
405407
async def _run_live_impl(

src/google/adk/agents/loop_agent.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,13 @@ async def _run_async_impl(
9191
current_sub_agent=sub_agent.name,
9292
times_looped=times_looped,
9393
)
94-
yield self._create_agent_state_event(ctx, agent_state=agent_state)
94+
ctx.set_agent_state(self.name, agent_state=agent_state)
95+
yield self._create_agent_state_event(ctx)
9596

9697
# Reset the sub-agent's state in the context to ensure that each
9798
# sub-agent starts fresh.
9899
if not is_resuming_at_current_agent:
99-
ctx.reset_agent_state(sub_agent.name)
100+
ctx.set_agent_state(sub_agent.name)
100101
is_resuming_at_current_agent = False
101102

102103
async with Aclosing(sub_agent.run_async(ctx)) as agen:
@@ -119,7 +120,8 @@ async def _run_async_impl(
119120
return
120121

121122
if ctx.is_resumable:
122-
yield self._create_agent_state_event(ctx, end_of_agent=True)
123+
ctx.set_agent_state(self.name, end_of_agent=True)
124+
yield self._create_agent_state_event(ctx)
123125

124126
def _get_start_state(
125127
self,

src/google/adk/agents/parallel_agent.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -181,14 +181,15 @@ async def _run_async_impl(
181181

182182
agent_state = self._load_agent_state(ctx, BaseAgentState)
183183
if ctx.is_resumable and agent_state is None:
184-
yield self._create_agent_state_event(ctx, agent_state=BaseAgentState())
184+
ctx.set_agent_state(self.name, agent_state=BaseAgentState())
185+
yield self._create_agent_state_event(ctx)
185186

186187
agent_runs = []
187188
# Prepare and collect async generators for each sub-agent.
188189
for sub_agent in self.sub_agents:
189190
if agent_state is None:
190191
# Reset sub-agent state to make sure each sub-agent starts fresh.
191-
ctx.reset_agent_state(sub_agent.name)
192+
ctx.set_agent_state(sub_agent.name)
192193

193194
sub_agent_ctx = _create_branch_ctx_for_sub_agent(self, sub_agent, ctx)
194195

@@ -215,8 +216,11 @@ async def _run_async_impl(
215216
return
216217

217218
# Once all sub-agents are done, mark the ParallelAgent as final.
218-
if ctx.is_resumable:
219-
yield self._create_agent_state_event(ctx, end_of_agent=True)
219+
if ctx.is_resumable and all(
220+
ctx.end_of_agents.get(sub_agent.name) for sub_agent in self.sub_agents
221+
):
222+
ctx.set_agent_state(self.name, end_of_agent=True)
223+
yield self._create_agent_state_event(ctx)
220224

221225
finally:
222226
for sub_agent_run in agent_runs:

src/google/adk/agents/sequential_agent.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,12 @@ async def _run_async_impl(
7070
# already been logged, so we should avoid yielding it again.
7171
if ctx.is_resumable:
7272
agent_state = SequentialAgentState(current_sub_agent=sub_agent.name)
73-
yield self._create_agent_state_event(ctx, agent_state=agent_state)
73+
ctx.set_agent_state(self.name, agent_state=agent_state)
74+
yield self._create_agent_state_event(ctx)
7475

7576
# Reset the sub-agent's state in the context to ensure that each
7677
# sub-agent starts fresh.
77-
ctx.reset_agent_state(sub_agent.name)
78+
ctx.set_agent_state(sub_agent.name)
7879

7980
async with Aclosing(sub_agent.run_async(ctx)) as agen:
8081
async for event in agen:
@@ -90,7 +91,8 @@ async def _run_async_impl(
9091
resuming_sub_agent = False
9192

9293
if ctx.is_resumable:
93-
yield self._create_agent_state_event(ctx, end_of_agent=True)
94+
ctx.set_agent_state(self.name, end_of_agent=True)
95+
yield self._create_agent_state_event(ctx)
9496

9597
def _get_start_index(
9698
self,

tests/unittests/agents/test_base_agent.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -929,9 +929,10 @@ async def test_create_agent_state_event():
929929

930930
ctx.branch = 'test_branch'
931931

932-
# Test case 1: with state
932+
# Test case 1: set agent state in context
933933
state = _TestAgentState(test_field='checkpoint')
934-
event = agent._create_agent_state_event(ctx, agent_state=state)
934+
ctx.set_agent_state(agent.name, agent_state=state)
935+
event = agent._create_agent_state_event(ctx)
935936
assert event is not None
936937
assert event.invocation_id == ctx.invocation_id
937938
assert event.author == agent.name
@@ -941,8 +942,9 @@ async def test_create_agent_state_event():
941942
assert event.actions.agent_state == state.model_dump(mode='json')
942943
assert not event.actions.end_of_agent
943944

944-
# Test case 2: with end_of_agent
945-
event = agent._create_agent_state_event(ctx, end_of_agent=True)
945+
# Test case 2: set end_of_agent in context
946+
ctx.set_agent_state(agent.name, end_of_agent=True)
947+
event = agent._create_agent_state_event(ctx)
946948
assert event is not None
947949
assert event.invocation_id == ctx.invocation_id
948950
assert event.author == agent.name
@@ -951,16 +953,8 @@ async def test_create_agent_state_event():
951953
assert event.actions.end_of_agent
952954
assert event.actions.agent_state is None
953955

954-
# Test case 3: with both state and end_of_agent
955-
state = _TestAgentState(test_field='checkpoint')
956-
event = agent._create_agent_state_event(
957-
ctx, agent_state=state, end_of_agent=True
958-
)
959-
assert event is not None
960-
assert event.actions.agent_state == state.model_dump(mode='json')
961-
assert event.actions.end_of_agent
962-
963-
# Test case 4: with neither
956+
# Test case 3: reset agent state and end_of_agent in context
957+
ctx.set_agent_state(agent.name)
964958
event = agent._create_agent_state_event(ctx)
965959
assert event is not None
966960
assert event.actions.agent_state is None

tests/unittests/agents/test_invocation_context.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,49 @@ def test_populate_invocation_agent_states_no_content(self):
347347
assert not invocation_context.agent_states
348348
assert not invocation_context.end_of_agents
349349

350+
def test_set_agent_state_with_end_of_agent_true(self):
351+
"""Tests that set_agent_state clears agent_state and sets end_of_agent to True."""
352+
invocation_context = self._create_test_invocation_context(
353+
ResumabilityConfig(is_resumable=True)
354+
)
355+
invocation_context.agent_states['agent1'] = {}
356+
invocation_context.end_of_agents['agent1'] = False
357+
358+
# Set state with end_of_agent=True, which should clear the existing
359+
# agent_state.
360+
invocation_context.set_agent_state('agent1', end_of_agent=True)
361+
assert 'agent1' not in invocation_context.agent_states
362+
assert invocation_context.end_of_agents['agent1']
363+
364+
def test_set_agent_state_with_agent_state(self):
365+
"""Tests that set_agent_state sets agent_state and sets end_of_agent to False."""
366+
agent_state = BaseAgentState()
367+
invocation_context = self._create_test_invocation_context(
368+
ResumabilityConfig(is_resumable=True)
369+
)
370+
invocation_context.end_of_agents['agent1'] = True
371+
372+
# Set state with agent_state=agent_state, which should set the agent_state
373+
# and reset the end_of_agent flag to False.
374+
invocation_context.set_agent_state('agent1', agent_state=agent_state)
375+
assert invocation_context.agent_states['agent1'] == agent_state.model_dump(
376+
mode='json'
377+
)
378+
assert invocation_context.end_of_agents['agent1'] is False
379+
380+
def test_reset_agent_state(self):
381+
"""Tests that set_agent_state clears agent_state and end_of_agent."""
382+
invocation_context = self._create_test_invocation_context(
383+
ResumabilityConfig(is_resumable=True)
384+
)
385+
invocation_context.agent_states['agent1'] = {}
386+
invocation_context.end_of_agents['agent1'] = True
387+
388+
# Reset state, which should clear the agent_state and end_of_agent flag.
389+
invocation_context.set_agent_state('agent1')
390+
assert 'agent1' not in invocation_context.agent_states
391+
assert 'agent1' not in invocation_context.end_of_agents
392+
350393

351394
class TestFindMatchingFunctionCall:
352395
"""Test suite for find_matching_function_call."""

tests/unittests/agents/test_parallel_agent.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ async def _run_async_impl(
5252
) -> AsyncGenerator[Event, None]:
5353
await asyncio.sleep(self.delay)
5454
yield self.event(ctx)
55+
if ctx.is_resumable:
56+
ctx.set_agent_state(self.name, end_of_agent=True)
5557

5658

5759
async def _create_parent_invocation_context(

0 commit comments

Comments
 (0)