Skip to content

Commit a0f8f38

Browse files
committed
Add UT
1 parent 247c913 commit a0f8f38

File tree

1 file changed

+85
-0
lines changed

1 file changed

+85
-0
lines changed

tests/unittests/test_runners.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1237,5 +1237,90 @@ def test_infer_agent_origin_detects_mismatch_for_user_agent(
12371237
assert "actual_name" in runner._app_name_alignment_hint
12381238

12391239

1240+
@pytest.mark.asyncio
1241+
async def test_temp_state_accessible_in_callbacks_but_not_persisted():
1242+
"""Tests that temp: state variables are accessible during lifecycle callbacks
1243+
but not persisted in the session."""
1244+
1245+
# Track what state was seen during callbacks
1246+
state_seen_in_before_agent = {}
1247+
state_seen_in_after_agent = {}
1248+
1249+
class StateAccessPlugin(BasePlugin):
1250+
"""Plugin that accesses state during callbacks."""
1251+
1252+
async def before_agent_callback(self, *, agent, callback_context):
1253+
# Set a temp state variable
1254+
callback_context.state["temp:test_key"] = "test_value"
1255+
callback_context.state["normal_key"] = "normal_value"
1256+
1257+
# Verify we can read it back immediately
1258+
state_seen_in_before_agent["temp:test_key"] = callback_context.state.get(
1259+
"temp:test_key"
1260+
)
1261+
state_seen_in_before_agent["normal_key"] = callback_context.state.get(
1262+
"normal_key"
1263+
)
1264+
return None
1265+
1266+
async def after_agent_callback(self, *, agent, callback_context):
1267+
# Verify temp state is still accessible during the same invocation
1268+
state_seen_in_after_agent["temp:test_key"] = callback_context.state.get(
1269+
"temp:test_key"
1270+
)
1271+
state_seen_in_after_agent["normal_key"] = callback_context.state.get(
1272+
"normal_key"
1273+
)
1274+
return None
1275+
1276+
# Setup
1277+
session_service = InMemorySessionService()
1278+
plugin = StateAccessPlugin(name="state_access")
1279+
1280+
agent = MockAgent(name="test_agent")
1281+
runner = Runner(
1282+
app_name=TEST_APP_ID,
1283+
agent=agent,
1284+
session_service=session_service,
1285+
plugins=[plugin],
1286+
auto_create_session=True,
1287+
)
1288+
1289+
# Run the agent
1290+
events = []
1291+
async for event in runner.run_async(
1292+
user_id=TEST_USER_ID,
1293+
session_id=TEST_SESSION_ID,
1294+
new_message=types.Content(
1295+
role="user", parts=[types.Part(text="test message")]
1296+
),
1297+
):
1298+
events.append(event)
1299+
1300+
# Verify temp state was accessible during callbacks
1301+
assert state_seen_in_before_agent["temp:test_key"] == "test_value"
1302+
assert state_seen_in_before_agent["normal_key"] == "normal_value"
1303+
assert state_seen_in_after_agent["temp:test_key"] == "test_value"
1304+
assert state_seen_in_after_agent["normal_key"] == "normal_value"
1305+
1306+
# Verify temp state is NOT persisted in the session
1307+
session = await session_service.get_session(
1308+
app_name=TEST_APP_ID,
1309+
user_id=TEST_USER_ID,
1310+
session_id=TEST_SESSION_ID,
1311+
)
1312+
1313+
# Normal state should be persisted
1314+
assert session.state.get("normal_key") == "normal_value"
1315+
1316+
# Temp state should NOT be persisted
1317+
assert "temp:test_key" not in session.state
1318+
1319+
# Verify temp state is also not in any event's state_delta
1320+
for event in session.events:
1321+
if event.actions and event.actions.state_delta:
1322+
assert "temp:test_key" not in event.actions.state_delta
1323+
1324+
12401325
if __name__ == "__main__":
12411326
pytest.main([__file__])

0 commit comments

Comments
 (0)