Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions python/packages/jumpstarter/jumpstarter/client/lease.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,9 @@ async def _acquire(self):
# Old controllers (pre-918d6341) mark offline-but-matching
# exporters as Unsatisfiable with reason "NoExporter".
# This is transient — retry with a new lease.
if condition_present_and_equal(result.conditions, "Unsatisfiable", "True", "NoExporter"):
if condition_present_and_equal(
result.conditions, "Unsatisfiable", "True", "NoExporter"
):
await self._handle_no_exporter_retry(spinner, message)
continue
logger.debug("Lease %s cannot be satisfied: %s", self.name, message)
Expand Down Expand Up @@ -330,16 +332,29 @@ async def handle_async(self, stream):
if remaining <= 0:
logger.debug(
"Exporter not ready and dial timeout (%.1fs) exceeded after %d attempts",
self.dial_timeout,
attempt + 1,
self.dial_timeout, attempt + 1
)
raise
delay = min(base_delay * (2**attempt), max_delay, remaining)
delay = min(base_delay * (2 ** attempt), max_delay, remaining)
logger.debug(
"Exporter not ready, retrying Dial in %.1fs (attempt %d, %.1fs remaining)",
delay,
attempt + 1,
remaining,
delay, attempt + 1, remaining
)
await sleep(delay)
attempt += 1
continue
if e.code() == grpc.StatusCode.UNAVAILABLE:
remaining = deadline - time.monotonic()
if remaining <= 0:
logger.warning(
"Exporter unavailable and dial timeout (%.1fs) exceeded after %d attempts",
self.dial_timeout, attempt + 1
)
raise
delay = min(base_delay * (2 ** attempt), max_delay, remaining)
logger.warning(
"Exporter unavailable, retrying Dial in %.1fs (attempt %d, %.1fs remaining)",
delay, attempt + 1, remaining
)
await sleep(delay)
attempt += 1
Expand All @@ -348,7 +363,8 @@ async def handle_async(self, stream):
if "permission denied" in str(e.details()).lower():
self.lease_transferred = True
logger.warning(
"Lease %s has been transferred to another client. Your session is no longer valid.",
"Lease %s has been transferred to another client. "
"Your session is no longer valid.",
self.name,
)
else:
Expand Down
76 changes: 76 additions & 0 deletions python/packages/jumpstarter/jumpstarter/client/lease_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,29 @@
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, Mock, patch

import grpc
import pytest
from grpc.aio import AioRpcError
from rich.console import Console

from jumpstarter.client.exceptions import LeaseError
from jumpstarter.client.lease import Lease, LeaseAcquisitionSpinner


class MockAioRpcError(AioRpcError):
    """Lightweight stand-in for grpc.aio.AioRpcError in tests.

    Only ``code()`` and ``details()`` are exercised by the code under test,
    so just those two accessors are stubbed out.
    """

    def __init__(self, status_code, message=""):
        # Deliberately do NOT call AioRpcError.__init__ — it demands full
        # gRPC call metadata that tests have no use for.
        self._stub_code = status_code
        self._stub_details = message

    def code(self):
        """Return the stubbed gRPC status code."""
        return self._stub_code

    def details(self):
        """Return the stubbed error detail string."""
        return self._stub_details


class TestLeaseAcquisitionSpinner:
"""Test cases for LeaseAcquisitionSpinner class."""

Expand Down Expand Up @@ -554,3 +570,63 @@ async def get_then_fail():
callback.assert_called()
_, remain_arg = callback.call_args[0]
assert remain_arg == timedelta(0)


class TestHandleAsyncUnavailableRetry:
    """Tests for Lease.handle_async UNAVAILABLE retry behavior."""

    def _make_lease_for_handle(self):
        """Build a minimal Lease for driving handle_async directly.

        object.__new__ bypasses Lease.__init__ (which presumably performs
        real controller setup — TODO confirm); only the attributes that
        handle_async reads are filled in.
        """
        lease = object.__new__(Lease)
        lease.name = "test-lease"
        lease.dial_timeout = 5.0
        lease.lease_transferred = False
        lease.tls_config = Mock()
        lease.grpc_options = {}
        lease.controller = Mock()
        return lease

    @pytest.mark.anyio
    async def test_handle_async_retries_unavailable_then_succeeds(self):
        """Dial returns UNAVAILABLE once then succeeds on retry."""
        lease = self._make_lease_for_handle()
        dial_call_count = 0

        # Fake Dial: fail with UNAVAILABLE on the first call only, then
        # return a stub response carrying the router endpoint/token.
        async def mock_dial(request):
            nonlocal dial_call_count
            dial_call_count += 1
            if dial_call_count == 1:
                raise MockAioRpcError(grpc.StatusCode.UNAVAILABLE, "temporarily unavailable")
            return Mock(router_endpoint="endpoint", router_token="token")

        lease.controller.Dial = mock_dial

        with patch("jumpstarter.client.lease.connect_router_stream") as mock_connect:
            # connect_router_stream is used as an async context manager, so
            # the patched mock needs awaitable __aenter__/__aexit__.
            mock_connect.return_value.__aenter__ = AsyncMock()
            mock_connect.return_value.__aexit__ = AsyncMock(return_value=False)
            stream = Mock()

            await lease.handle_async(stream)

        # Exactly one failed attempt plus one successful retry.
        assert dial_call_count == 2
        # The Dial response fields must be forwarded verbatim to the router.
        mock_connect.assert_called_once_with("endpoint", "token", stream, lease.tls_config, lease.grpc_options)

    @pytest.mark.anyio
    async def test_handle_async_unavailable_exceeds_dial_timeout(self):
        """Dial returns UNAVAILABLE until dial_timeout is exceeded, then raises."""
        lease = self._make_lease_for_handle()
        # Short timeout keeps the real backoff sleeps brief so the test is fast.
        lease.dial_timeout = 0.5
        dial_call_count = 0

        # Fake Dial that never recovers: every attempt raises UNAVAILABLE.
        async def mock_dial(request):
            nonlocal dial_call_count
            dial_call_count += 1
            raise MockAioRpcError(grpc.StatusCode.UNAVAILABLE, "permanently unavailable")

        lease.controller.Dial = mock_dial
        stream = Mock()

        # Once the dial deadline passes, the last UNAVAILABLE error should
        # propagate to the caller instead of being retried forever.
        with pytest.raises(AioRpcError) as exc_info:
            await lease.handle_async(stream)

        assert exc_info.value.code() == grpc.StatusCode.UNAVAILABLE
        # At least one retry must have happened before giving up.
        assert dial_call_count >= 2
26 changes: 18 additions & 8 deletions python/packages/jumpstarter/jumpstarter/client/status_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,8 @@ async def _poll_loop(self): # noqa: C901
return

deadline_retries = 0
unavailable_retries = 0
unavailable_max_retries = 10

Comment on lines +325 to 327
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Reset unavailable_retries on non-UNAVAILABLE errors to keep the threshold truly consecutive.

At Line 394, unavailable_retries increments correctly, but it is only reset on successful polls (Line 348). If DEADLINE_EXCEEDED (or another RPC error) occurs between UNAVAILABLEs, the counter still carries over, so connection loss can be triggered without 10 consecutive UNAVAILABLEs.

Suggested patch
             except AioRpcError as e:
                 if e.code() == StatusCode.UNIMPLEMENTED:
                     logger.debug("GetStatus not implemented (server), assuming LEASE_READY")
                     self._signal_unsupported()
                     break
                 elif e.code() == StatusCode.UNAVAILABLE:
                     unavailable_retries += 1
                     if unavailable_retries >= unavailable_max_retries:
                         logger.warning(
                             "GetStatus UNAVAILABLE %d times consecutively, marking connection as lost",
                             unavailable_retries,
                         )
                         self._connection_lost = True
                         self._running = False
                         self._any_change_event.set()
                         self._any_change_event = Event()
                         break
                     elif unavailable_retries % 5 == 0:
                         logger.warning("GetStatus UNAVAILABLE %d times consecutively", unavailable_retries)
                     else:
                         logger.debug("GetStatus UNAVAILABLE (attempt %d), retrying...", unavailable_retries)
                 elif e.code() == StatusCode.DEADLINE_EXCEEDED:
+                    unavailable_retries = 0
                     # DEADLINE_EXCEEDED is a transient error (RPC timed out), not a
                     # permanent connection loss. Keep polling - the shell's own timeout
                     # on wait_for_any_of is the real deadline. Only UNAVAILABLE indicates
                     # a true connection loss (server down/disconnected).
                     deadline_retries += 1
                     if deadline_retries >= 20:
@@
                     else:
                         logger.debug("GetStatus timed out (attempt %d), retrying...", deadline_retries)
                     continue
+                else:
+                    unavailable_retries = 0
                 logger.debug(f"GetStatus poll error: {e.code()}")

Also applies to: 348-348, 394-409

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@python/packages/jumpstarter/jumpstarter/client/status_monitor.py` around
lines 325 - 327, The unavailable_retries counter must only count consecutive
UNAVAILABLE errors: in the status polling function (where unavailable_retries
and unavailable_max_retries are defined and where unavailable_retries is
incremented at the UNAVAILABLE branch) change the logic so that
unavailable_retries is incremented only when the RPC/status is UNAVAILABLE and
is explicitly reset to 0 for any other outcome (successful poll,
DEADLINE_EXCEEDED, other RPC errors, or exceptions). Locate the block that
inspects the RPC status (the place that currently increments unavailable_retries
at UNAVAILABLE and only resets on success) and add a branch to set
unavailable_retries = 0 whenever the status is not UNAVAILABLE so the threshold
truly requires consecutive UNAVAILABLEs.

while self._running:
try:
Expand All @@ -343,6 +345,7 @@ async def _poll_loop(self): # noqa: C901
logger.info("Connection recovered, resetting connection_lost flag")
self._connection_lost = False
deadline_retries = 0
unavailable_retries = 0

# Detect missed transitions
if self._status_version > 0 and new_version > self._status_version + 1:
Expand Down Expand Up @@ -388,14 +391,21 @@ async def _poll_loop(self): # noqa: C901
self._signal_unsupported()
break
elif e.code() == StatusCode.UNAVAILABLE:
# Connection lost - exporter closed or restarted
logger.info("Connection lost (UNAVAILABLE), signaling waiters")
self._connection_lost = True
self._running = False
# Fire the change event to wake up any waiters
self._any_change_event.set()
self._any_change_event = Event()
break
unavailable_retries += 1
if unavailable_retries >= unavailable_max_retries:
logger.warning(
"GetStatus UNAVAILABLE %d times consecutively, marking connection as lost",
unavailable_retries,
)
self._connection_lost = True
self._running = False
self._any_change_event.set()
self._any_change_event = Event()
break
elif unavailable_retries % 5 == 0:
logger.warning("GetStatus UNAVAILABLE %d times consecutively", unavailable_retries)
else:
logger.debug("GetStatus UNAVAILABLE (attempt %d), retrying...", unavailable_retries)
elif e.code() == StatusCode.DEADLINE_EXCEEDED:
# DEADLINE_EXCEEDED is a transient error (RPC timed out), not a
# permanent connection loss. Keep polling - the shell's own timeout
Expand Down
Loading
Loading