Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions components/runners/ambient-runner/ag_ui_claude_sdk/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ def build_options(
self,
input_data: Optional[RunAgentInput] = None,
thread_id: Optional[str] = None,
resume_from: Optional[str] = None,
) -> "ClaudeAgentOptions":
"""
Build ClaudeAgentOptions from stored options (object/dict/None) plus dynamic tools.
Expand All @@ -378,6 +379,8 @@ def build_options(
Args:
input_data: Optional RunAgentInput for extracting dynamic tools
thread_id: Optional thread_id for session resumption lookup
resume_from: Optional CLI session ID to resume (preserves chat history
across adapter rebuilds, e.g. after a repo is added mid-session)

Returns:
Configured ClaudeAgentOptions instance
Expand Down Expand Up @@ -451,6 +454,11 @@ def build_options(

# Remove api_key from options kwargs (handled via environment variable)
merged_kwargs.pop("api_key", None)

# Resume from a previous CLI session (preserves chat context)
if resume_from:
merged_kwargs["resume"] = resume_from

logger.debug(f"Merged kwargs after pop: {merged_kwargs}")

# Apply forwarded_props as per-run overrides (before adding dynamic tools)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def __init__(self) -> None:
self._allowed_tools: list[str] = []
self._system_prompt: dict = {}
self._stderr_lines: list[str] = []
# Preserved session IDs across adapter rebuilds (e.g. repo additions)
self._saved_session_ids: dict[str, str] = {}

# ------------------------------------------------------------------
# PlatformBridge interface
Expand Down Expand Up @@ -99,7 +101,13 @@ async def run(self, input_data: RunAgentInput) -> AsyncIterator[BaseEvent]:
# 4. Get or create session worker for this thread
thread_id = input_data.thread_id or self._context.session_id
api_key = os.getenv("ANTHROPIC_API_KEY", "")
sdk_options = self._adapter.build_options(input_data, thread_id=thread_id)
saved_session_id = (
self._saved_session_ids.pop(thread_id, None)
or self._session_manager.get_session_id(thread_id)
)
sdk_options = self._adapter.build_options(
input_data, thread_id=thread_id, resume_from=saved_session_id
)
worker = await self._session_manager.get_or_create(
thread_id, sdk_options, api_key
)
Expand All @@ -121,6 +129,11 @@ async def run(self, input_data: RunAgentInput) -> AsyncIterator[BaseEvent]:
async for event in wrapped_stream:
yield event

# Persist session ID after turn completes (for --resume on pod restart)
if worker.session_id:
self._session_manager._session_ids[thread_id] = worker.session_id
self._session_manager._persist_session_ids()

Comment on lines +132 to +136
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Direct access to private _session_ids and _persist_session_ids() breaks encapsulation.

The bridge manipulates SessionManager internals directly. Consider adding a public method like update_session_id(thread_id, session_id) that handles both assignment and persistence:

♻️ Suggested approach

In session.py, add:

def update_session_id(self, thread_id: str, session_id: str) -> None:
    """Update and persist session ID for a thread."""
    self._session_ids[thread_id] = session_id
    self._persist_session_ids()

Then in bridge.py:

-            if worker.session_id:
-                self._session_manager._session_ids[thread_id] = worker.session_id
-                self._session_manager._persist_session_ids()
+            if worker.session_id:
+                self._session_manager.update_session_id(thread_id, worker.session_id)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@components/runners/ambient-runner/ambient_runner/bridges/claude/bridge.py`
around lines 132 - 136, The bridge is directly manipulating SessionManager
internals (_session_ids and _persist_session_ids), breaking encapsulation;
modify SessionManager (class SessionManager) to add a public method like
update_session_id(thread_id, session_id) that sets the mapping and calls
persistence, then replace direct accesses in bridge.py (where worker.session_id
and thread_id are used) to call session_manager.update_session_id(thread_id,
worker.session_id) instead of touching _session_ids or _persist_session_ids
directly.

self._first_run = False

async def interrupt(self, thread_id: Optional[str] = None) -> None:
Expand Down Expand Up @@ -167,6 +180,9 @@ def mark_dirty(self) -> None:
self._first_run = True
self._adapter = None
if self._session_manager:
# Preserve session IDs so --resume works after adapter rebuild.
# Must be captured synchronously before the async shutdown task runs.
self._saved_session_ids.update(self._session_manager.get_all_session_ids())
manager = self._session_manager
self._session_manager = None
_async_safe_manager_shutdown(manager)
Expand Down Expand Up @@ -279,7 +295,11 @@ async def _setup_platform(self) -> None:
"""Full platform setup: auth, workspace, MCP, observability."""
# Session manager
if self._session_manager is None:
self._session_manager = SessionManager()
state_dir = os.path.join(
os.getenv("WORKSPACE_PATH", "/workspace"),
os.getenv("RUNNER_STATE_DIR", ".claude"),
)
self._session_manager = SessionManager(state_dir=state_dir)

# Claude-specific auth
from ambient_runner.bridges.claude.auth import setup_sdk_authentication
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@
"""

import asyncio
import json
import logging
import os
from contextlib import suppress
from pathlib import Path
from typing import Any, AsyncIterator, Optional

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -100,7 +102,11 @@ async def _run(self) -> None:

os.environ["ANTHROPIC_API_KEY"] = self._api_key

from ambient_runner.bridges.claude.mock_client import MOCK_API_KEY, MockClaudeSDKClient
from ambient_runner.bridges.claude.mock_client import (
MOCK_API_KEY,
MockClaudeSDKClient,
)

if self._api_key == MOCK_API_KEY:
logger.info("[SessionWorker] Using MockClaudeSDKClient (replay mode)")
client: Any = MockClaudeSDKClient(options=self._options)
Expand Down Expand Up @@ -248,13 +254,18 @@ class SessionManager:
mix messages on the single underlying SDK client).

Tracks session IDs returned by the CLI so that workers can be recreated
with ``--resume`` after a pod restart.
with ``--resume`` after a pod restart. Session IDs are persisted to disk
so they survive pod restarts.
"""

def __init__(self) -> None:
_SESSION_IDS_FILE = "claude_session_ids.json"

def __init__(self, state_dir: str = "") -> None:
    """Create empty worker/lock tables and load any session IDs saved on disk.

    Args:
        state_dir: Directory that holds the session-ID file. Stored as
            ``_state_dir``; defaults to an empty string.
    """
    self._state_dir = state_dir
    self._session_ids: dict[str, str] = {}  # thread_id -> CLI session_id
    self._locks: dict[str, asyncio.Lock] = {}
    self._workers: dict[str, SessionWorker] = {}
    # Called last so every attribute above exists before the restore runs.
    self._restore_session_ids()

async def get_or_create(
self,
Expand Down Expand Up @@ -302,6 +313,14 @@ def get_session_id(self, thread_id: str) -> Optional[str]:
return worker.session_id
return self._session_ids.get(thread_id)

def get_all_session_ids(self) -> dict[str, str]:
    """Snapshot every known session ID, keyed by thread ID.

    Starts from the cached ``_session_ids`` mapping and overlays the
    ``session_id`` of each worker in ``_workers`` that has a truthy one, so
    a worker's value wins over a cached value for the same thread.

    Returns:
        A new dict; changing it does not modify ``_session_ids``.
    """
    live_ids = {
        thread_id: session_worker.session_id
        for thread_id, session_worker in self._workers.items()
        if session_worker.session_id
    }
    return {**self._session_ids, **live_ids}

async def destroy(self, thread_id: str) -> None:
"""Stop and remove the worker for *thread_id*.

Expand All @@ -312,6 +331,7 @@ async def destroy(self, thread_id: str) -> None:
if worker is not None:
if worker.session_id:
self._session_ids[thread_id] = worker.session_id
self._persist_session_ids()
await worker.stop()
self._locks.pop(thread_id, None)
logger.debug(f"[SessionManager] Destroyed worker for thread={thread_id}")
Expand All @@ -322,3 +342,39 @@ async def shutdown(self) -> None:
for tid in thread_ids:
await self.destroy(tid)
logger.info("[SessionManager] All workers shut down")

# ── session ID persistence ──

def _session_ids_path(self) -> Path | None:
if not self._state_dir:
return None
return Path(self._state_dir) / self._SESSION_IDS_FILE

def _persist_session_ids(self) -> None:
    """Save session IDs to disk for --resume across pod restarts.

    Does nothing when there is no state path or no session IDs to save.
    The JSON is written to a sibling ``<name>.tmp`` file and then moved over
    the real file, so a crash or serialisation error part-way through never
    truncates or corrupts the previously persisted IDs.

    ``OSError`` is treated as best-effort: it is logged at debug level and
    swallowed. Any other exception (for example ``TypeError`` from
    ``json.dump``) propagates after the temp file is removed.
    """
    path = self._session_ids_path()
    if not path or not self._session_ids:
        return
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        path.parent.mkdir(parents=True, exist_ok=True)
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(self._session_ids, f)
            # Rename within the same directory: the old file stays intact
            # until the new one is completely written.
            tmp_path.replace(path)
        except BaseException:
            # Do not leave a half-written temp file behind.
            with suppress(OSError):
                tmp_path.unlink()
            raise
        logger.info("Persisted %d session ID(s) to %s", len(self._session_ids), path)
    except OSError:
        logger.debug("Could not persist session IDs to %s", path, exc_info=True)

def _restore_session_ids(self) -> None:
    """Restore session IDs from disk (written by a previous pod).

    Merges the entries of the JSON state file into ``_session_ids``. Restore
    is best-effort: a missing file, an unreadable file, invalid JSON or
    undecodable bytes leave ``_session_ids`` untouched and are logged at
    debug level instead of raising. A file whose top-level value is not a
    JSON object is ignored. Only entries whose value is a non-empty string
    are accepted, so malformed data cannot enter the ``thread_id ->
    session_id`` map.
    """
    path = self._session_ids_path()
    if not path or not path.exists():
        return
    try:
        with open(path, encoding="utf-8") as f:
            restored = json.load(f)
    except (OSError, ValueError):
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        logger.debug("Could not restore session IDs from %s", path, exc_info=True)
        return
    if not isinstance(restored, dict):
        return
    usable = {
        thread_id: session_id
        for thread_id, session_id in restored.items()
        if isinstance(session_id, str) and session_id
    }
    self._session_ids.update(usable)
    logger.info("Restored %d Claude session ID(s) from %s", len(usable), path)
Loading