diff --git a/src/claude/sdk_integration.py b/src/claude/sdk_integration.py
index adf553f4..5d7c2fe3 100644
--- a/src/claude/sdk_integration.py
+++ b/src/claude/sdk_integration.py
@@ -146,6 +146,16 @@ def __init__(
         else:
             logger.info("No API key provided, using existing Claude CLI authentication")
 
+    def _is_retryable_error(self, exc: BaseException) -> bool:
+        """Return True for transient errors that warrant a retry.
+        asyncio.TimeoutError is intentional (user-configured timeout) — not retried.
+        Only non-MCP CLIConnectionError is considered transient.
+        """
+        if isinstance(exc, CLIConnectionError):
+            msg = str(exc).lower()
+            return "mcp" not in msg  # "server" alone is too broad
+        return False
+
     async def execute_command(
         self,
         prompt: str,
@@ -288,11 +298,49 @@ async def _run_client() -> None:
                 finally:
                     await client.disconnect()
 
-            # Execute with timeout
-            await asyncio.wait_for(
-                _run_client(),
-                timeout=self.config.claude_timeout_seconds,
-            )
+            # Execute with timeout, retrying on transient CLIConnectionError
+            max_attempts = max(1, self.config.claude_retry_max_attempts)
+
+            for attempt in range(max_attempts):
+                # Reset message accumulator each attempt so that a failed attempt
+                # does not pollute the next one with partial/duplicate messages.
+                # _run_client() closes over `messages` by reference (late-binding
+                # closure), so clearing it here is seen by every new call.
+                messages.clear()
+
+                if attempt > 0:
+                    delay = self.config.claude_retry_base_delay * (
+                        self.config.claude_retry_backoff_factor ** (attempt - 1)
+                    )
+                    if self.config.claude_retry_max_delay > 0:  # 0 = no cap
+                        delay = min(delay, self.config.claude_retry_max_delay)
+                    logger.warning(
+                        "Retrying Claude SDK command",
+                        attempt=attempt + 1,
+                        max_attempts=max_attempts,
+                        delay_seconds=delay,
+                    )
+                    await asyncio.sleep(delay)
+                # Note: asyncio.TimeoutError raised by wait_for is intentionally
+                # NOT caught below — it propagates immediately to the outer
+                # `except asyncio.TimeoutError` handler, bypassing the retry
+                # loop entirely. Timeouts reflect a user-configured hard limit
+                # and should not be retried.
+                try:
+                    await asyncio.wait_for(
+                        _run_client(),
+                        timeout=self.config.claude_timeout_seconds,
+                    )
+                    break  # success — exit retry loop
+                except CLIConnectionError as exc:
+                    if self._is_retryable_error(exc) and attempt < max_attempts - 1:
+                        logger.warning(
+                            "Transient connection error, will retry",
+                            attempt=attempt + 1,
+                            error=str(exc),
+                        )
+                        continue
+                    raise  # non-retryable or attempts exhausted
 
             # Extract cost, tools, and session_id from result message
             cost = 0.0
diff --git a/src/config/settings.py b/src/config/settings.py
index 77c34ea4..5b96cf62 100644
--- a/src/config/settings.py
+++ b/src/config/settings.py
@@ -26,6 +26,10 @@
     DEFAULT_RATE_LIMIT_BURST,
     DEFAULT_RATE_LIMIT_REQUESTS,
     DEFAULT_RATE_LIMIT_WINDOW,
+    DEFAULT_RETRY_BACKOFF_FACTOR,
+    DEFAULT_RETRY_BASE_DELAY,
+    DEFAULT_RETRY_MAX_ATTEMPTS,
+    DEFAULT_RETRY_MAX_DELAY,
     DEFAULT_SESSION_TIMEOUT_HOURS,
 )
 
@@ -121,6 +125,34 @@ class Settings(BaseSettings):
         description="List of explicitly disallowed Claude tools/commands",
     )
 
+    # Retry settings
+    claude_retry_max_attempts: int = Field(
+        DEFAULT_RETRY_MAX_ATTEMPTS,
+        ge=0,
+        description="Max total attempts for transient SDK errors (0 or 1 = no retry)",
+    )
+    claude_retry_base_delay: float = Field(
+        DEFAULT_RETRY_BASE_DELAY,
+        ge=0,
+        description=(
+            "Base delay in seconds between retries. "
+            "0 means retries are attempted immediately with no pause."
+        ),
+    )
+    claude_retry_backoff_factor: float = Field(
+        DEFAULT_RETRY_BACKOFF_FACTOR,
+        gt=0,
+        description="Exponential backoff multiplier",
+    )
+    claude_retry_max_delay: float = Field(
+        DEFAULT_RETRY_MAX_DELAY,
+        ge=0,
+        description=(
+            "Maximum delay cap in seconds. "
+            "0 disables the cap entirely (delays grow unbounded with backoff)."
+        ),
+    )
+
     # Sandbox settings
     sandbox_enabled: bool = Field(
         True,
diff --git a/src/utils/constants.py b/src/utils/constants.py
index 5ea9a4c3..7b66f9a6 100644
--- a/src/utils/constants.py
+++ b/src/utils/constants.py
@@ -85,5 +85,11 @@
 DEFAULT_CLAUDE_BINARY = "claude"
 DEFAULT_CLAUDE_OUTPUT_FORMAT = "stream-json"
 
+# Retry defaults
+DEFAULT_RETRY_MAX_ATTEMPTS = 3
+DEFAULT_RETRY_BASE_DELAY = 1.0
+DEFAULT_RETRY_BACKOFF_FACTOR = 3.0
+DEFAULT_RETRY_MAX_DELAY = 30.0
+
 # Logging
 LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
diff --git a/tests/unit/test_claude/test_sdk_integration.py b/tests/unit/test_claude/test_sdk_integration.py
index 17ba58ab..9c2b3777 100644
--- a/tests/unit/test_claude/test_sdk_integration.py
+++ b/tests/unit/test_claude/test_sdk_integration.py
@@ -390,6 +390,129 @@ async def test_execute_command_no_resume_for_new_session(self, sdk_manager):
             not hasattr(captured_options[0], "resume") or not captured_options[0].resume
         )
 
+    async def test_retry_on_transient_cli_connection_error(self, sdk_manager):
+        """Test that transient CLIConnectionError triggers retry and succeeds."""
+        from claude_agent_sdk import CLIConnectionError
+
+        call_count = 0
+
+        async def flaky_receive():
+            nonlocal call_count
+            call_count += 1
+            if call_count == 1:
+                raise CLIConnectionError("connection reset")
+            # Second attempt succeeds - yield a ResultMessage
+            yield
+
+        # Use a config with 2 attempts
+        sdk_manager.config.claude_retry_max_attempts = 2
+
+        client = AsyncMock()
+        client.connect = AsyncMock()
+        client.disconnect = AsyncMock()
+        client.query = AsyncMock()
+        query_mock = AsyncMock()
+        query_mock.receive_messages = flaky_receive
+        client._query = query_mock
+
+        # Should not raise - second attempt succeeds
+        with patch("src.claude.sdk_integration.ClaudeSDKClient", return_value=client):
+            with patch("asyncio.sleep", new_callable=AsyncMock):
+                try:
+                    await sdk_manager.execute_command(
+                        prompt="Test",
+                        working_directory=Path("/test"),
+                    )
+                except Exception:
+                    pass  # Response parsing may fail - what matters is retry happened
+        assert call_count == 2
+
+    async def test_no_retry_on_mcp_connection_error(self, sdk_manager):
+        """Test that MCP CLIConnectionError is NOT retried."""
+        from claude_agent_sdk import CLIConnectionError
+
+        from src.claude.exceptions import ClaudeMCPError
+
+        client = AsyncMock()
+        client.connect = AsyncMock()
+        client.disconnect = AsyncMock()
+        client.query = AsyncMock(side_effect=CLIConnectionError("mcp server failed"))
+
+        with patch("src.claude.sdk_integration.ClaudeSDKClient", return_value=client):
+            with pytest.raises((ClaudeMCPError, Exception)):
+                await sdk_manager.execute_command(
+                    prompt="Test",
+                    working_directory=Path("/test"),
+                )
+        # Only called once - no retry for MCP errors
+        assert client.query.call_count == 1
+
+    async def test_retry_disabled_when_max_attempts_zero(self, sdk_manager):
+        """Test that setting max_attempts=0 effectively disables retries (1 attempt)."""
+        from claude_agent_sdk import CLIConnectionError
+
+        sdk_manager.config.claude_retry_max_attempts = 0
+
+        client = AsyncMock()
+        client.connect = AsyncMock()
+        client.disconnect = AsyncMock()
+        client.query = AsyncMock(side_effect=CLIConnectionError("connection reset"))
+
+        with patch("src.claude.sdk_integration.ClaudeSDKClient", return_value=client):
+            with pytest.raises(Exception):
+                await sdk_manager.execute_command(
+                    prompt="Test",
+                    working_directory=Path("/test"),
+                )
+        # A retryable error must still surface after the single allowed attempt
+        assert client.query.call_count == 1
+
+    async def test_retry_max_delay_zero_disables_cap(self, sdk_manager):
+        """Test that max_delay=0 leaves the backoff delay uncapped (not zeroed)."""
+        from claude_agent_sdk import CLIConnectionError
+
+        sdk_manager.config.claude_retry_max_attempts = 3
+        sdk_manager.config.claude_retry_base_delay = 1.0
+        sdk_manager.config.claude_retry_backoff_factor = 3.0
+        sdk_manager.config.claude_retry_max_delay = 0
+
+        client = AsyncMock()
+        client.connect = AsyncMock()
+        client.disconnect = AsyncMock()
+        client.query = AsyncMock(side_effect=CLIConnectionError("connection reset"))
+
+        with patch("src.claude.sdk_integration.ClaudeSDKClient", return_value=client):
+            with patch("asyncio.sleep", new_callable=AsyncMock) as sleep_mock:
+                with pytest.raises(Exception):
+                    await sdk_manager.execute_command(
+                        prompt="Test",
+                        working_directory=Path("/test"),
+                    )
+        # Second retry waits base * factor**1 = 3.0s; a cap of 0 would make it 0
+        sleep_mock.assert_any_await(3.0)
+
+    def test_is_retryable_error_transient(self, sdk_manager):
+        """Test _is_retryable_error returns True for transient connection errors."""
+        from claude_agent_sdk import CLIConnectionError
+
+        assert (
+            sdk_manager._is_retryable_error(CLIConnectionError("connection reset"))
+            is True
+        )
+
+    def test_is_retryable_error_mcp(self, sdk_manager):
+        """Test _is_retryable_error returns False for MCP errors."""
+        from claude_agent_sdk import CLIConnectionError
+
+        assert (
+            sdk_manager._is_retryable_error(CLIConnectionError("mcp server failed"))
+            is False
+        )
+
+    def test_is_retryable_error_timeout(self, sdk_manager):
+        """Test _is_retryable_error returns False for timeout errors."""
+        assert sdk_manager._is_retryable_error(asyncio.TimeoutError()) is False
+
 
 
 class TestClaudeSandboxSettings:
     """Test sandbox and system_prompt settings on ClaudeAgentOptions."""