diff --git a/components/runners/ambient-runner/ag_ui_claude_sdk/adapter.py b/components/runners/ambient-runner/ag_ui_claude_sdk/adapter.py
index 127729a52..2e0bce56e 100644
--- a/components/runners/ambient-runner/ag_ui_claude_sdk/adapter.py
+++ b/components/runners/ambient-runner/ag_ui_claude_sdk/adapter.py
@@ -369,6 +369,7 @@ def build_options(
         self,
         input_data: Optional[RunAgentInput] = None,
         thread_id: Optional[str] = None,
+        resume_from: Optional[str] = None,
     ) -> "ClaudeAgentOptions":
         """
         Build ClaudeAgentOptions from stored options (object/dict/None) plus dynamic tools.
@@ -378,6 +379,8 @@ def build_options(
         Args:
             input_data: Optional RunAgentInput for extracting dynamic tools
             thread_id: Optional thread_id for session resumption lookup
+            resume_from: Optional CLI session ID to resume (preserves chat history
+                across adapter rebuilds, e.g. after a repo is added mid-session)
 
         Returns:
             Configured ClaudeAgentOptions instance
@@ -451,6 +454,11 @@ def build_options(
 
         # Remove api_key from options kwargs (handled via environment variable)
         merged_kwargs.pop("api_key", None)
+
+        # Resume from a previous CLI session (preserves chat context)
+        if resume_from:
+            merged_kwargs["resume"] = resume_from
+
         logger.debug(f"Merged kwargs after pop: {merged_kwargs}")
 
         # Apply forwarded_props as per-run overrides (before adding dynamic tools)
diff --git a/components/runners/ambient-runner/ambient_runner/bridges/claude/bridge.py b/components/runners/ambient-runner/ambient_runner/bridges/claude/bridge.py
index ec77ed0ae..dfd02f88a 100644
--- a/components/runners/ambient-runner/ambient_runner/bridges/claude/bridge.py
+++ b/components/runners/ambient-runner/ambient_runner/bridges/claude/bridge.py
@@ -56,6 +56,8 @@ def __init__(self) -> None:
         self._allowed_tools: list[str] = []
         self._system_prompt: dict = {}
         self._stderr_lines: list[str] = []
+        # Preserved session IDs across adapter rebuilds (e.g. repo additions)
+        self._saved_session_ids: dict[str, str] = {}
 
     # ------------------------------------------------------------------
     # PlatformBridge interface
@@ -99,7 +101,13 @@ async def run(self, input_data: RunAgentInput) -> AsyncIterator[BaseEvent]:
         # 4. Get or create session worker for this thread
         thread_id = input_data.thread_id or self._context.session_id
         api_key = os.getenv("ANTHROPIC_API_KEY", "")
-        sdk_options = self._adapter.build_options(input_data, thread_id=thread_id)
+        saved_session_id = (
+            self._saved_session_ids.pop(thread_id, None)
+            or self._session_manager.get_session_id(thread_id)
+        )
+        sdk_options = self._adapter.build_options(
+            input_data, thread_id=thread_id, resume_from=saved_session_id
+        )
         worker = await self._session_manager.get_or_create(
             thread_id, sdk_options, api_key
         )
@@ -121,6 +129,10 @@ async def run(self, input_data: RunAgentInput) -> AsyncIterator[BaseEvent]:
             async for event in wrapped_stream:
                 yield event
 
+        # Persist session ID after turn completes (for --resume on pod restart)
+        if worker.session_id:
+            self._session_manager.record_session_id(thread_id, worker.session_id)
+
         self._first_run = False
 
     async def interrupt(self, thread_id: Optional[str] = None) -> None:
@@ -167,6 +179,9 @@ def mark_dirty(self) -> None:
         self._first_run = True
         self._adapter = None
         if self._session_manager:
+            # Preserve session IDs so --resume works after adapter rebuild.
+            # Must be captured synchronously before the async shutdown task runs.
+            self._saved_session_ids.update(self._session_manager.get_all_session_ids())
             manager = self._session_manager
             self._session_manager = None
             _async_safe_manager_shutdown(manager)
@@ -279,7 +294,11 @@ async def _setup_platform(self) -> None:
         """Full platform setup: auth, workspace, MCP, observability."""
         # Session manager
         if self._session_manager is None:
-            self._session_manager = SessionManager()
+            state_dir = os.path.join(
+                os.getenv("WORKSPACE_PATH", "/workspace"),
+                os.getenv("RUNNER_STATE_DIR", ".claude"),
+            )
+            self._session_manager = SessionManager(state_dir=state_dir)
 
         # Claude-specific auth
         from ambient_runner.bridges.claude.auth import setup_sdk_authentication
diff --git a/components/runners/ambient-runner/ambient_runner/bridges/claude/session.py b/components/runners/ambient-runner/ambient_runner/bridges/claude/session.py
index 1e581f191..158f16e40 100644
--- a/components/runners/ambient-runner/ambient_runner/bridges/claude/session.py
+++ b/components/runners/ambient-runner/ambient_runner/bridges/claude/session.py
@@ -26,9 +26,11 @@
 """
 
 import asyncio
+import json
 import logging
 import os
 from contextlib import suppress
+from pathlib import Path
 from typing import Any, AsyncIterator, Optional
 
 logger = logging.getLogger(__name__)
@@ -100,7 +102,11 @@ async def _run(self) -> None:
 
         os.environ["ANTHROPIC_API_KEY"] = self._api_key
 
-        from ambient_runner.bridges.claude.mock_client import MOCK_API_KEY, MockClaudeSDKClient
+        from ambient_runner.bridges.claude.mock_client import (
+            MOCK_API_KEY,
+            MockClaudeSDKClient,
+        )
+
         if self._api_key == MOCK_API_KEY:
             logger.info("[SessionWorker] Using MockClaudeSDKClient (replay mode)")
             client: Any = MockClaudeSDKClient(options=self._options)
@@ -248,13 +254,18 @@ class SessionManager:
     mix messages on the single underlying SDK client).
 
     Tracks session IDs returned by the CLI so that workers can be recreated
-    with ``--resume`` after a pod restart.
+    with ``--resume`` after a pod restart. Session IDs are persisted to disk
+    so they survive pod restarts.
     """
 
-    def __init__(self) -> None:
+    _SESSION_IDS_FILE = "claude_session_ids.json"
+
+    def __init__(self, state_dir: str = "") -> None:
         self._workers: dict[str, SessionWorker] = {}
         self._locks: dict[str, asyncio.Lock] = {}
         self._session_ids: dict[str, str] = {}  # thread_id -> CLI session_id
+        self._state_dir = state_dir
+        self._restore_session_ids()
 
     async def get_or_create(
         self,
@@ -302,6 +313,14 @@ def get_session_id(self, thread_id: str) -> Optional[str]:
             return worker.session_id
         return self._session_ids.get(thread_id)
 
+    def get_all_session_ids(self) -> dict[str, str]:
+        """Return a snapshot of all known session IDs (live workers + cached)."""
+        result = dict(self._session_ids)
+        for tid, worker in self._workers.items():
+            if worker.session_id:
+                result[tid] = worker.session_id
+        return result
+
     async def destroy(self, thread_id: str) -> None:
         """Stop and remove the worker for *thread_id*.
 
@@ -312,6 +331,7 @@ async def destroy(self, thread_id: str) -> None:
         if worker is not None:
             if worker.session_id:
                 self._session_ids[thread_id] = worker.session_id
+                self._persist_session_ids()
             await worker.stop()
         self._locks.pop(thread_id, None)
         logger.debug(f"[SessionManager] Destroyed worker for thread={thread_id}")
@@ -322,3 +342,59 @@ async def shutdown(self) -> None:
         for tid in thread_ids:
             await self.destroy(tid)
         logger.info("[SessionManager] All workers shut down")
+
+    # ── session ID persistence ──
+
+    def record_session_id(self, thread_id: str, session_id: str) -> None:
+        """Remember *session_id* for *thread_id* and persist it to disk."""
+        self._session_ids[thread_id] = session_id
+        self._persist_session_ids()
+
+    def _session_ids_path(self) -> Optional[Path]:
+        """Return the persistence file path, or None without a state dir."""
+        if not self._state_dir:
+            return None
+        return Path(self._state_dir) / self._SESSION_IDS_FILE
+
+    def _persist_session_ids(self) -> None:
+        """Save session IDs to disk for --resume across pod restarts."""
+        path = self._session_ids_path()
+        if not path or not self._session_ids:
+            return
+        # Write a sibling temp file and rename it into place so a pod killed
+        # mid-write never leaves a truncated JSON file behind.
+        tmp_path = path.with_name(path.name + ".tmp")
+        try:
+            path.parent.mkdir(parents=True, exist_ok=True)
+            with open(tmp_path, "w", encoding="utf-8") as f:
+                json.dump(self._session_ids, f)
+            os.replace(tmp_path, path)
+        except OSError:
+            logger.warning("Could not persist session IDs to %s", path, exc_info=True)
+            return
+        logger.info("Persisted %d session ID(s) to %s", len(self._session_ids), path)
+
+    def _restore_session_ids(self) -> None:
+        """Restore session IDs from disk (written by a previous pod)."""
+        path = self._session_ids_path()
+        if not path or not path.exists():
+            return
+        try:
+            with open(path, encoding="utf-8") as f:
+                restored = json.load(f)
+        except (OSError, ValueError):
+            # ValueError covers json.JSONDecodeError and UnicodeDecodeError, so
+            # a corrupt file cannot make the constructor raise.
+            logger.warning("Could not restore session IDs from %s", path, exc_info=True)
+            return
+        if not isinstance(restored, dict):
+            logger.warning("Ignoring malformed session ID file %s", path)
+            return
+        # Keep only non-empty string session IDs; anything else is unusable.
+        valid = {
+            tid: sid
+            for tid, sid in restored.items()
+            if isinstance(sid, str) and sid
+        }
+        self._session_ids.update(valid)
+        logger.info("Restored %d Claude session ID(s) from %s", len(valid), path)