Skip to content
Open
93 changes: 72 additions & 21 deletions src/google/adk/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def __init__(
memory_service: Optional[BaseMemoryService] = None,
credential_service: Optional[BaseCredentialService] = None,
plugin_close_timeout: float = 5.0,
auto_create_session: bool = False,
):
"""Initializes the Runner.

Expand All @@ -175,6 +176,9 @@ def __init__(
memory_service: The memory service for the runner.
credential_service: The credential service for the runner.
plugin_close_timeout: The timeout in seconds for plugin close methods.
auto_create_session: Whether to automatically create a session when
not found. Defaults to False. If False, a missing session raises
ValueError with a helpful message.

Raises:
ValueError: If `app` is provided along with `agent` or `plugins`, or if
Expand All @@ -195,6 +199,7 @@ def __init__(
self.plugin_manager = PluginManager(
plugins=plugins, close_timeout=plugin_close_timeout
)
self.auto_create_session = auto_create_session
(
self._agent_origin_app_name,
self._agent_origin_dir,
Expand Down Expand Up @@ -343,9 +348,56 @@ def _format_session_not_found_message(self, session_id: str) -> str:
return message
return (
f'{message}. {self._app_name_alignment_hint} '
'The mismatch prevents the runner from locating the session.'
'The mismatch prevents the runner from locating the session. '
'To automatically create a session when missing, set '
'auto_create_session=True when constructing the runner.'
)

async def _get_or_create_session(
self, *, user_id: str, session_id: str
) -> Session:
"""Gets the session or creates it if missing.

This helper first attempts to retrieve the session. If not found, it
creates a new session with the provided identifiers.

Args:
user_id: The user ID of the session.
session_id: The session ID of the session.

Returns:
The existing or newly created `Session`.
"""
session = await self.session_service.get_session(
app_name=self.app_name, user_id=user_id, session_id=session_id
)
if session:
return session
return await self.session_service.create_session(
app_name=self.app_name, user_id=user_id, session_id=session_id
)

async def _get_session(self, *, user_id: str, session_id: str) -> Session:
"""Gets a session or raises a helpful error if missing.

Args:
user_id: The user ID of the session.
session_id: The session ID of the session.

Returns:
The existing `Session`.

Raises:
ValueError: If the session cannot be found.
"""
session = await self.session_service.get_session(
app_name=self.app_name, user_id=user_id, session_id=session_id
)
if not session:
message = self._format_session_not_found_message(session_id)
raise ValueError(message)
return session

def run(
self,
*,
Expand Down Expand Up @@ -455,12 +507,14 @@ async def _run_with_trace(
invocation_id: Optional[str] = None,
) -> AsyncGenerator[Event, None]:
with tracer.start_as_current_span('invocation'):
session = await self.session_service.get_session(
app_name=self.app_name, user_id=user_id, session_id=session_id
)
if not session:
message = self._format_session_not_found_message(session_id)
raise ValueError(message)
if self.auto_create_session:
session = await self._get_or_create_session(
user_id=user_id, session_id=session_id
)
else:
session = await self._get_session(
user_id=user_id, session_id=session_id
)
if not invocation_id and not new_message:
raise ValueError(
'Running an agent requires either a new_message or an '
Expand Down Expand Up @@ -534,12 +588,12 @@ async def rewind_async(
rewind_before_invocation_id: str,
) -> None:
"""Rewinds the session to before the specified invocation."""
session = await self.session_service.get_session(
app_name=self.app_name, user_id=user_id, session_id=session_id
)
if not session:
raise ValueError(f'Session not found: {session_id}')

if self.auto_create_session:
session = await self._get_or_create_session(
user_id=user_id, session_id=session_id
)
else:
session = await self._get_session(user_id=user_id, session_id=session_id)
rewind_event_index = -1
for i, event in enumerate(session.events):
if event.invocation_id == rewind_before_invocation_id:
Expand Down Expand Up @@ -966,15 +1020,12 @@ async def run_live(
DeprecationWarning,
stacklevel=2,
)
if not session:
session = await self.session_service.get_session(
app_name=self.app_name, user_id=user_id, session_id=session_id
if self.auto_create_session:
session = await self._get_or_create_session(
user_id=user_id, session_id=session_id
)
if not session:
raise ValueError(
f'Session not found for user id: {user_id} and session id:'
f' {session_id}'
)
else:
session = await self._get_session(user_id=user_id, session_id=session_id)
invocation_context = self._new_invocation_context_for_live(
session,
live_request_queue=live_request_queue,
Expand Down
4 changes: 2 additions & 2 deletions src/google/adk/sessions/database_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,8 +436,8 @@ async def append_event(self, session: Session, event: Event) -> Event:
if storage_session.update_timestamp_tz > session.last_update_time:
raise ValueError(
"The last_update_time provided in the session object"
f" {datetime.fromtimestamp(session.last_update_time):'%Y-%m-%d %H:%M:%S'} is"
" earlier than the update_time in the storage_session"
f" {datetime.fromtimestamp(session.last_update_time):'%Y-%m-%d %H:%M:%S'}"
" is earlier than the update_time in the storage_session"
f" {datetime.fromtimestamp(storage_session.update_timestamp_tz):'%Y-%m-%d %H:%M:%S'}."
" Please check if it is a stale session."
)
Expand Down
121 changes: 121 additions & 0 deletions tests/unittests/test_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,24 @@ async def _run_async_impl(
)


class MockLiveAgent(BaseAgent):
"""Mock live agent for unit testing."""

def __init__(self, name: str):
super().__init__(name=name, sub_agents=[])

async def _run_live_impl(
self, invocation_context: InvocationContext
) -> AsyncGenerator[Event, None]:
yield Event(
invocation_id=invocation_context.invocation_id,
author=self.name,
content=types.Content(
role="model", parts=[types.Part(text="live hello")]
),
)


class MockLlmAgent(LlmAgent):
"""Mock LLM agent for unit testing."""

Expand Down Expand Up @@ -237,6 +255,109 @@ def _infer_agent_origin(
assert "Ensure the runner app_name matches" in message


@pytest.mark.asyncio
async def test_session_auto_creation():

class RunnerWithMismatch(Runner):

def _infer_agent_origin(
self, agent: BaseAgent
) -> tuple[Optional[str], Optional[Path]]:
del agent
return "expected_app", Path("/workspace/agents/expected_app")

session_service = InMemorySessionService()
runner = RunnerWithMismatch(
app_name="expected_app",
agent=MockLlmAgent("test_agent"),
session_service=session_service,
artifact_service=InMemoryArtifactService(),
auto_create_session=True,
)

agen = runner.run_async(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Preference to include tests for rewind_async and run_live as well. Thanks!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The requested changes have been addressed. Please note that I decided to keep the auto_create_session flag outside of the helper function to ensure better code readability. New test cases have also been included. I appreciate your review.

user_id="user",
session_id="missing",
new_message=types.Content(role="user", parts=[types.Part(text="hi")]),
)

event = await agen.__anext__()
await agen.aclose()

# Verify that session_id="missing" doesn't error out - session is auto-created
assert event.author == "test_agent"
assert event.content.parts[0].text == "Test LLM response"


@pytest.mark.asyncio
async def test_rewind_auto_create_session_on_missing_session():
"""When auto_create_session=True, rewind should create session if missing.

The newly created session won't contain the target invocation, so
`rewind_async` should raise an Invocation ID not found error (rather than
a session not found error), demonstrating auto-creation occurred.
"""
session_service = InMemorySessionService()
runner = Runner(
app_name="auto_create_app",
agent=MockLlmAgent("agent_for_rewind"),
session_service=session_service,
artifact_service=InMemoryArtifactService(),
auto_create_session=True,
)

with pytest.raises(ValueError, match=r"Invocation ID not found: inv_missing"):
await runner.rewind_async(
user_id="user",
session_id="missing",
rewind_before_invocation_id="inv_missing",
)

# Verify the session actually exists now due to auto-creation.
session = await session_service.get_session(
app_name="auto_create_app", user_id="user", session_id="missing"
)
assert session is not None
assert session.app_name == "auto_create_app"


@pytest.mark.asyncio
async def test_run_live_auto_create_session():
"""run_live should auto-create session when missing and yield events."""
session_service = InMemorySessionService()
artifact_service = InMemoryArtifactService()
runner = Runner(
app_name="live_app",
agent=MockLiveAgent("live_agent"),
session_service=session_service,
artifact_service=artifact_service,
auto_create_session=True,
)

# An empty LiveRequestQueue is sufficient for our mock agent.
from google.adk.agents.live_request_queue import LiveRequestQueue

live_queue = LiveRequestQueue()

agen = runner.run_live(
user_id="user",
session_id="missing",
live_request_queue=live_queue,
)

event = await agen.__anext__()
await agen.aclose()

assert event.author == "live_agent"
assert event.content.parts[0].text == "live hello"

# Session should have been created automatically.
session = await session_service.get_session(
app_name="live_app", user_id="user", session_id="missing"
)
assert session is not None


@pytest.mark.asyncio
async def test_runner_allows_nested_agent_directories(tmp_path, monkeypatch):
project_root = tmp_path / "workspace"
Expand Down