Skip to content
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import types
from functools import wraps
from types import TracebackType
from typing import NoReturn

import click

Expand Down Expand Up @@ -205,45 +204,62 @@


def _handle_connection_error_with_reauth(exc, login_func):
"""Handle ConnectionError with reauthentication logic."""
"""Handle ConnectionError with reauthentication logic.

Returns True if re-auth succeeded and the caller should retry.
Raises if re-auth is not applicable or fails.
"""
if "expired" in str(exc).lower():
click.echo(click.style("Token is expired, triggering re-authentication", fg="red"))
config = exc.get_config()
login_func(config)
raise ClickExceptionRed("Please try again now") from None
return True
else:
raise ClickExceptionRed(str(exc)) from None


def _handle_single_exception_with_reauth(exc, login_func):
"""Handle a single exception (may raise)."""
"""Handle a single exception (may raise).

Returns True if re-auth succeeded and the caller should retry.
"""
if isinstance(exc, ConnectionError):
_handle_connection_error_with_reauth(exc, login_func)
return _handle_connection_error_with_reauth(exc, login_func)
elif cli_exc := _map_cli_exception(exc):
raise cli_exc from None
# Not handled: fall through
return False


def _handle_exception_group_with_reauth(eg, login_func):
"""Handle exceptions wrapped in BaseExceptionGroup.

def _handle_exception_group_with_reauth(eg, login_func) -> NoReturn:
"""Handle exceptions wrapped in BaseExceptionGroup."""
Returns True if re-auth succeeded and the caller should retry.
"""
for exc in leaf_exceptions(eg, fix_tracebacks=False):
_handle_single_exception_with_reauth(exc, login_func)
if _handle_single_exception_with_reauth(exc, login_func):
return True
Comment thread
coderabbitai[bot] marked this conversation as resolved.
# If no handled exceptions, re-raise the original group
raise eg


def handle_exceptions_with_reauthentication(login_func):

Check failure on line 244 in python/packages/jumpstarter-cli-common/jumpstarter_cli_common/exceptions.py

View workflow job for this annotation

GitHub Actions / lint-python

ruff (C901)

python/packages/jumpstarter-cli-common/jumpstarter_cli_common/exceptions.py:244:5: C901 `handle_exceptions_with_reauthentication` is too complex (16 > 10)
"""Decorator to handle exceptions in blocking functions, including those wrapped in BaseExceptionGroup."""
"""Decorator to handle exceptions in blocking functions, including those wrapped in BaseExceptionGroup.

When the wrapped function raises a token-expired error, the decorator re-authenticates
via login_func and retries the function exactly once.
"""

def decorator(func):

Check failure on line 251 in python/packages/jumpstarter-cli-common/jumpstarter_cli_common/exceptions.py

View workflow job for this annotation

GitHub Actions / lint-python

ruff (C901)

python/packages/jumpstarter-cli-common/jumpstarter_cli_common/exceptions.py:251:9: C901 `decorator` is too complex (15 > 10)
@wraps(func)
def wrapped(*args, **kwargs):

Check failure on line 253 in python/packages/jumpstarter-cli-common/jumpstarter_cli_common/exceptions.py

View workflow job for this annotation

GitHub Actions / lint-python

ruff (C901)

python/packages/jumpstarter-cli-common/jumpstarter_cli_common/exceptions.py:253:13: C901 `wrapped` is too complex (14 > 10)
retry = False
try:
return func(*args, **kwargs)
except BaseExceptionGroup as eg:
_handle_exception_group_with_reauth(eg, login_func)
if _handle_exception_group_with_reauth(eg, login_func):
retry = True
except (ConnectionError, JumpstarterException, click.ClickException) as e:
_handle_single_exception_with_reauth(e, login_func)
if _handle_single_exception_with_reauth(e, login_func):
retry = True
except Exception as e:
if cli_exc := _map_cli_exception(e):
raise cli_exc from None
Expand All @@ -253,6 +269,18 @@
raise cli_exc from None
raise

if retry:
try:
return func(*args, **kwargs)
except Exception as e:
if cli_exc := _map_cli_exception(e):
raise cli_exc from None
raise
except KeyboardInterrupt as e:
if cli_exc := _map_cli_exception(e):
raise cli_exc from None
raise

return wrapped

return decorator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,53 @@ def grpc_auth_fn():
grpc_auth_fn()


def test_handle_exceptions_with_reauth_retries_on_expired_token() -> None:
"""After successful re-auth, the command is retried automatically."""
from jumpstarter.common.exceptions import ConnectionError

call_count = [0]

def login_func(config):
config.token = "new_token"

@handle_exceptions_with_reauthentication(login_func)
def command_fn(config=None):
call_count[0] += 1
if call_count[0] == 1:
exc = ConnectionError("token is expired")
exc.set_config(config)
raise exc
return "result"

config = type("Config", (), {"token": "old_token"})()
result = command_fn(config=config)

assert result == "result"
assert call_count[0] == 2


def test_handle_exceptions_with_reauth_does_not_retry_twice() -> None:
"""If retry also fails with expired token, the error propagates."""
from jumpstarter.common.exceptions import ConnectionError

login_calls = [0]

def login_func(config):
login_calls[0] += 1

@handle_exceptions_with_reauthentication(login_func)
def always_expired_fn(config=None):
exc = ConnectionError("token is expired")
exc.set_config(config)
raise exc

config = type("Config", (), {"token": "tok"})()
with pytest.raises(click.ClickException):
always_expired_fn(config=config)

assert login_calls[0] == 1


def test_handle_exceptions_maps_grpc_invalid_argument() -> None:
class MockGrpcError(Exception):
def code(self):
Expand Down
108 changes: 107 additions & 1 deletion python/packages/jumpstarter-cli-common/jumpstarter_cli_common/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,43 @@
import os
import ssl
import time
import warnings
from dataclasses import dataclass
from functools import wraps
from typing import ClassVar

import aiohttp
import anyio
import certifi
import click
from aiohttp import web
from anyio import create_memory_object_stream
from anyio.to_thread import run_sync

# Authlib still pulls deprecated authlib.jose on first OAuth2Session use, not only at
# import time; a transient catch_warnings() around the import does not suppress that.
warnings.filterwarnings(
"ignore",
message=r"authlib\.jose module is deprecated.*",
module=r"authlib\..*",
)

# When the user opts into --insecure-tls, we set verify=False on the requests session.
# urllib3 emits InsecureRequestWarning for every such request; suppress it since the
# user has already acknowledged the risk.
warnings.filterwarnings(
"ignore",
message=r"Unverified HTTPS request is being made.*",
module=r"urllib3\..*",
)
Comment on lines +34 to +38
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

InsecureRequestWarning is suppressed globally, not conditionally on --insecure-tls

The comment says "when the user opts into --insecure-tls" but the filter is applied unconditionally at module import time. Any unverified HTTPS request made anywhere in the process — including from unrelated code paths — will silently drop its warning for every user, not just those who opted in.

Since the authlib comment above (lines 18-19) correctly explains why the warning fires at runtime not at import time, the right fix for both warning filters is to move them after all imports. For the InsecureRequestWarning specifically, the suppression should be conditional:

🛡️ Proposed fix: conditional suppression

Remove the module-level InsecureRequestWarning filter entirely, and add it where insecure_tls=True is first used, e.g., in Config.client():

 def client(self, **kwargs):
+    if self.insecure_tls:
+        warnings.filterwarnings(
+            "ignore",
+            message=r"Unverified HTTPS request is being made.*",
+            module=r"urllib3\..*",
+        )
     session = OAuth2Session(client_id=self.client_id, scope=self._scopes(), **kwargs)
     session.verify = False if self.insecure_tls else (os.environ.get("SSL_CERT_FILE") or certifi.where())
     return session
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@python/packages/jumpstarter-cli-common/jumpstarter_cli_common/oidc.py` around
lines 29 - 33, The current module-level warnings.filterwarnings call that
silences Unverified HTTPS warnings must be removed and the suppression made
conditional where the insecure_tls option is applied; delete the global
warnings.filterwarnings(...) for InsecureRequestWarning in
jumpstarter_cli_common/oidc.py and instead add a conditional
warnings.filterwarnings("ignore", category=InsecureRequestWarning, ...) inside
the code path that consumes insecure_tls (e.g., inside Config.client() or
wherever Config.client(insecure_tls=...) is invoked) so the warning is only
suppressed when insecure_tls=True; keep the authlib-related filter moved so all
warning filters run after imports.

Comment on lines +34 to +38
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The InsecureRequestWarning suppression runs unconditionally at module import time, which means any code that imports oidc silently hides legitimate TLS errors even when the user never passed --insecure-tls.

Consider moving the warnings.filterwarnings call into Config.client() or Config.__init__, gated on self.insecure_tls, so the warning is only suppressed when the user explicitly opted in.

from authlib.integrations.requests_client import OAuth2Session

Check failure on line 34 in python/packages/jumpstarter-cli-common/jumpstarter_cli_common/oidc.py

View workflow job for this annotation

GitHub Actions / lint-python

ruff (E402)

python/packages/jumpstarter-cli-common/jumpstarter_cli_common/oidc.py:34:1: E402 Module level import not at top of file
from joserfc.jws import extract_compact

Check failure on line 35 in python/packages/jumpstarter-cli-common/jumpstarter_cli_common/oidc.py

View workflow job for this annotation

GitHub Actions / lint-python

ruff (E402)

python/packages/jumpstarter-cli-common/jumpstarter_cli_common/oidc.py:35:1: E402 Module level import not at top of file
from yarl import URL

Check failure on line 36 in python/packages/jumpstarter-cli-common/jumpstarter_cli_common/oidc.py

View workflow job for this annotation

GitHub Actions / lint-python

ruff (E402)

python/packages/jumpstarter-cli-common/jumpstarter_cli_common/oidc.py:36:1: E402 Module level import not at top of file

from jumpstarter.config.env import JMP_OIDC_CALLBACK_PORT
from jumpstarter.config.env import JMP_OIDC_CALLBACK_PORT, JMP_OIDC_DEVICE_FLOW

Check failure on line 38 in python/packages/jumpstarter-cli-common/jumpstarter_cli_common/oidc.py

View workflow job for this annotation

GitHub Actions / lint-python

ruff (E402)

python/packages/jumpstarter-cli-common/jumpstarter_cli_common/oidc.py:38:1: E402 Module level import not at top of file
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

DEVICE_FLOW_POLL_INTERVAL = 5
DEVICE_FLOW_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:device_code"


def _get_ssl_context() -> ssl.SSLContext:
Expand Down Expand Up @@ -46,6 +68,12 @@
default=True,
help="Request offline_access scope (refresh token)",
)
@click.option(
"--device-flow/--no-device-flow",
"device_flow",
default=None,
help="Use OAuth 2.0 Device Authorization Grant (auto-detected in Dev Spaces)",
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh nice! I didn't know this would be auto-detected in DevSpaces.

)
@wraps(f)
def wrapper(*args, **kwds):
return f(*args, **kwds)
Expand Down Expand Up @@ -177,6 +205,84 @@
lambda: client.fetch_token(config["token_endpoint"], authorization_response=authorization_response)
)

async def device_code_grant(self):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a lot of responsibility in one class.

Consider extracting grant-specific logic into standalone functions or strategy classes. For example, device_code_grant could be a standalone function that receives a Config and returns tokens.

config = await self.configuration()

device_endpoint = config.get("device_authorization_endpoint")
if not device_endpoint:
raise click.ClickException(
"OIDC provider does not support Device Authorization Grant. "
"Ensure 'oauth2DeviceAuthorizationGrantEnabled' is enabled on the OIDC client."
)

ssl_context: ssl.SSLContext | bool = False if self.insecure_tls else _get_ssl_context()
connector = aiohttp.TCPConnector(ssl=ssl_context)
async with aiohttp.ClientSession(connector=connector) as session:
async with session.post(
device_endpoint,
data={
"client_id": self.client_id,
"scope": " ".join(self._scopes()),
},
) as resp:
if resp.status != 200:
body = await resp.text()
raise click.ClickException(f"Device authorization request failed (HTTP {resp.status}): {body}")
device_data = await resp.json()

device_code = device_data["device_code"]
user_code = device_data["user_code"]
verification_uri = device_data.get("verification_uri_complete") or device_data["verification_uri"]
expires_in = device_data.get("expires_in", 600)
interval = device_data.get("interval", DEVICE_FLOW_POLL_INTERVAL)

print(
f"\nOpen the following URL in your browser and enter the code:"
f"\n {verification_uri}"
f"\n Code: {user_code}\n"
)

token_endpoint = config["token_endpoint"]
deadline = time.monotonic() + expires_in

while time.monotonic() < deadline:
await anyio.sleep(interval)

client = self.client()
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The device flow polling loop creates a new self.client() (OAuth2Session wrapping a requests.Session with a connection pool) on every iteration.

try:
token = await run_sync(
lambda c=client: c.fetch_token(
token_endpoint,
grant_type=DEVICE_FLOW_GRANT_TYPE,
device_code=device_code,
)
)
return token
except Exception as e:
error_msg = str(e)
if "authorization_pending" in error_msg:
continue
if "slow_down" in error_msg:
interval += 5
continue
if "expired_token" in error_msg or "expired" in error_msg:
raise click.ClickException("Device code expired. Please try again.") from e
if "access_denied" in error_msg:
raise click.ClickException("Authorization was denied by the user.") from e
raise click.ClickException(f"Device flow token request failed: {e}") from e
Comment on lines +261 to +272
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The authlib library typically provides structured OAuthError with an .error attribute. Consider checking for specific exception types and using structured attributes, falling back to string matching only for unknown types.


raise click.ClickException("Device code expired (timeout). Please try again.")


def should_use_device_flow(device_flow_flag: bool | None = None) -> bool:
if device_flow_flag is not None:
return device_flow_flag
if os.environ.get(JMP_OIDC_DEVICE_FLOW, "").strip() == "1":
return True
if os.environ.get("VSCODE_INJECTION", "").strip() == "1":
return True
return False


def decode_jwt(token: str):
try:
Expand Down
Loading
Loading