Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor auth tokens retrieval/refresh process #85

Merged
merged 7 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions doc/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,11 @@
nitpicky = True

nitpick_ignore_regex = {
(r"py:obj", "ansys.simai.core.data.base.DataModelType"),
(r"py:class", "_io.BytesIO"),
(r"py:class", "pydantic_core._pydantic_core.Url"),
("py:obj", "ansys.simai.core.data.base.DataModelType"),
("py:class", "_io.BytesIO"),
("py:class", "pydantic_core._pydantic_core.Url"),
("py:class", "pydantic_core._pydantic_core.Annotated"),
("py:class", "pydantic.networks.UrlConstraints"),
}

source_suffix = ".rst"
Expand Down
883 changes: 514 additions & 369 deletions pdm.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ dependencies = [
"sseclient-py>=1.8.0,<3",
"wakepy>=0.8.0",
"tqdm>=4.66.1",
"filelock>=3.10.7",
]

[project.module]
Expand Down
268 changes: 144 additions & 124 deletions src/ansys/simai/core/utils/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,22 @@

import json
import logging
import os
import random
import threading
import time
import webbrowser
from datetime import datetime, timedelta
from typing import Optional
from datetime import datetime, timedelta, timezone
from typing import Any, ClassVar, Optional
from urllib.parse import urljoin

import requests
from pydantic import BaseModel, ValidationError
from filelock import FileLock
from pydantic import BaseModel, ValidationError, computed_field, model_validator
from requests.auth import AuthBase
from requests.models import PreparedRequest

from ansys.simai.core.errors import ConnectionError
from ansys.simai.core.errors import ApiClientError
from ansys.simai.core.utils.configuration import ClientConfig, Credentials
from ansys.simai.core.utils.files import get_cache_dir
from ansys.simai.core.utils.requests import handle_response
Expand All @@ -42,75 +46,105 @@


class _AuthTokens(BaseModel):
"""Represents the OIDC tokens received from the auth server.
"""Represents the OIDC tokens received from the auth server."""

The class can fetch and refresh these tokens.
It caches the tokens to disk automatically.
"""

_EXPIRATION_BUFFER = timedelta(seconds=5)
# Time buffer so we treat nearly invalid tokens as invalid.
# Random so multiple processes don't refresh tokens at the same time.
EXPIRATION_BUFFER: ClassVar = random.randrange(5, 15) # noqa: S311

cache_file: Optional[str] = None
access_token: str
expires_in: int
refresh_expires_in: int
expiration: datetime = None
refresh_expiration: datetime = None
expiration: datetime
refresh_expiration: datetime
refresh_token: str
# ... (unused fields were removed for simplicity)

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
if self.expiration is None:
now = datetime.utcnow()
self.expiration = now + timedelta(seconds=self.expires_in)
self.refresh_expiration = now + timedelta(seconds=self.refresh_expires_in)
if self.cache_file is not None:
with open(self.cache_file, "w") as f:
f.write(self.model_dump_json())
# ... (unused fields removed)
yias marked this conversation as resolved.
Show resolved Hide resolved

@model_validator(mode="before")
@classmethod
def from_cache(cls, file_path: Optional[str]) -> "Optional[_AuthTokens]":
if file_path is None:
return None
def expires_in_to_datetime(cls, data: Any) -> dict:
assert isinstance(data, dict)
awoimbee marked this conversation as resolved.
Show resolved Hide resolved
if "expiration" not in data:
# We want to store "expiration", but API responses contain "expires_in"
now = datetime.now(timezone.utc)
data["expiration"] = now + timedelta(seconds=int(data["expires_in"]))
data["refresh_expiration"] = now + timedelta(seconds=int(data["refresh_expires_in"]))
return data

def is_token_expired(self):
return self.expires_in < self.EXPIRATION_BUFFER

def is_refresh_token_expired(self):
return self.refresh_expires_in < self.EXPIRATION_BUFFER

@computed_field
@property
def expires_in(self) -> float:
return (self.expiration - datetime.now(timezone.utc)).total_seconds()

@computed_field
@property
def refresh_expires_in(self) -> float:
return (self.refresh_expiration - datetime.now(timezone.utc)).total_seconds()


class _AuthTokensRetriever:
"""Retrieve tokens via ``get_tokens()``.
It handles caching and the various auth token sources.
"""

# Time buffer to refresh the refresh_token before it becomes invalid.
# Random so multiple processes don't refresh the token at the same time
REFRESH_BUFFER = random.randrange(50, 120) # noqa: S311

def __init__(
self,
credentials: Optional["Credentials"],
session: requests.Session,
auth_cache_hash: str,
realm_url: str,
) -> None:
self.credentials = credentials
self.session = session
self.token_url = f"{realm_url}/protocol/openid-connect/token"
self.device_auth_url = f"{realm_url}/protocol/openid-connect/auth/device"
self.refresh_timer = threading.Timer(0, lambda: None)
self.cache_file_path = str(get_cache_dir() / f"tokens-{auth_cache_hash}.json")

def _get_token_from_cache(self) -> Optional[_AuthTokens]:
try:
auth = cls.parse_file(file_path)
# don't return expired auth tokens
if not auth.is_refresh_token_expired():
return auth
except (IOError, json.JSONDecodeError) as e:
logger.debug(f"Could not read auth cache: {e}")
except ValidationError as e:
logger.warning(f"ValidationError while reading auth cache: {e}")
with open(self.cache_file_path, "r") as f:
return _AuthTokens.model_validate_json(f.read())
except (IOError, json.JSONDecodeError, ValidationError) as e:
logger.info(f"Could not read auth token from cache: {e}")
return None

@classmethod
def from_request_direct_grant(
cls, session: requests.Session, token_url, creds: Credentials
) -> "_AuthTokens":
logger.debug("Getting authentication tokens.")
def _request_token_direct_grant(self) -> "_AuthTokens":
logger.debug("request authentication tokens via direct grant")
assert self.credentials
request_params = {
"client_id": "sdk",
"grant_type": "password",
"scope": "openid",
**creds.model_dump(),
**self.credentials.model_dump(),
}
return cls(**handle_response(session.post(token_url, data=request_params)))
return _AuthTokens(
**handle_response(self.session.post(self.token_url, data=request_params))
)

@classmethod
def from_request_device_auth(
cls, session: requests.Session, device_auth_url, token_url, cache_file=None
) -> "_AuthTokens":
def _request_token_device_auth(self) -> "_AuthTokens":
logger.debug("request authentication tokens via device auth")
auth_codes = handle_response(
session.post(device_auth_url, data={"client_id": "sdk", "scope": "openid"})
self.session.post(self.device_auth_url, data={"client_id": "sdk", "scope": "openid"})
)
print( # noqa: T201
f"Go to {auth_codes['verification_uri']} and enter the code {auth_codes['user_code']}"
)
webbrowser.open(auth_codes["verification_uri_complete"])
# loop will exit when auth server returns "400 Device code is expired"
while True:
validation = session.post(
token_url,
# polling can't be faster or auth server returns HTTP 400 Slow down
time.sleep(5)
validation = self.session.post(
self.token_url,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
Expand All @@ -119,54 +153,59 @@
},
)
if b"authorization_pending" not in validation.content:
return cls(**handle_response(validation), cache_file=cache_file)
# polling can't be faster or auth server returns HTTP 400 Slow down
time.sleep(5)

def is_token_expired(self):
return datetime.utcnow() > (self.expiration - self._EXPIRATION_BUFFER)

def is_refresh_token_expired(self):
return datetime.utcnow() > (self.refresh_expiration - self._EXPIRATION_BUFFER)
return _AuthTokens(**handle_response(validation))

def refresh(self, session: requests.Session, token_url: str) -> "_AuthTokens":
def _refresh_auth_token(self, refresh_token: str) -> Optional[_AuthTokens]:
logger.debug("Refreshing authentication tokens.")
request_params = {
"client_id": "sdk",
"grant_type": "refresh_token",
"refresh_token": self.refresh_token,
"refresh_token": refresh_token,
}
try:
return _AuthTokens(
**handle_response(session.post(token_url, data=request_params)),
cache_file=self.cache_file,
**handle_response(self.session.post(self.token_url, data=request_params))
)
except requests.exceptions.ConnectionError as e:
raise ConnectionError(e) from None


def _get_cached_or_request_device_auth(
session: requests.Session,
device_auth_url: str,
token_url: str,
auth_cache_hash: Optional[str],
):
auth = None
auth_cache_path = None
if auth_cache_hash: # Try to fetch auth from cache
auth_cache_path = str(get_cache_dir() / f"tokens-{auth_cache_hash}.json")
auth = _AuthTokens.from_cache(auth_cache_path)
if auth:
try:
auth = auth.refresh(session, token_url)
except Exception as e:
logger.debug(f"Auth refresh error: {e}")
auth = None
if auth is None: # Otherwise request the user to authenticate
auth = _AuthTokens.from_request_device_auth(
session, device_auth_url, token_url, auth_cache_path
except (requests.exceptions.ConnectionError, ApiClientError) as e:
logger.error(f"Could not refresh authentication tokens: {e}")
return None

def _schedule_auth_refresh(self, refresh_expires_in: float):
"""Schedule authentication refresh to avoids refresh token expiring if the client is idle for a long time."""
if refresh_expires_in <= self.REFRESH_BUFFER:
# Skip scheduling: refresh token too close to expiration (max SSO session nearly reached)
return

Check warning on line 177 in src/ansys/simai/core/utils/auth.py

View check run for this annotation

Codecov / codecov/patch

src/ansys/simai/core/utils/auth.py#L177

Added line #L177 was not covered by tests
self.refresh_timer.cancel()
self.refresh_timer = threading.Timer(
refresh_expires_in - self.REFRESH_BUFFER, self.get_tokens
)
return auth
self.refresh_timer.daemon = True
self.refresh_timer.start()

def get_tokens(self, force_refresh=False) -> _AuthTokens:
auth = self._get_token_from_cache()
if auth and not auth.is_token_expired() and not force_refresh:
# fast path: avoid locking the tokens, return early
self._schedule_auth_refresh(auth.refresh_expires_in)
return auth

Check warning on line 190 in src/ansys/simai/core/utils/auth.py

View check run for this annotation

Codecov / codecov/patch

src/ansys/simai/core/utils/auth.py#L189-L190

Added lines #L189 - L190 were not covered by tests
with FileLock(self.cache_file_path + ".lock", timeout=600):
# slow path: tokens are locked, will get refreshed
auth = self._get_token_from_cache() # might have changed while we waited for the lock
if auth and auth.is_refresh_token_expired():
auth = None

Check warning on line 195 in src/ansys/simai/core/utils/auth.py

View check run for this annotation

Codecov / codecov/patch

src/ansys/simai/core/utils/auth.py#L195

Added line #L195 was not covered by tests
if auth and (force_refresh or auth.is_token_expired()):
auth = self._refresh_auth_token(auth.refresh_token)
if auth is None:
if self.credentials:
auth = self._request_token_direct_grant()
else:
auth = self._request_token_device_auth()
with open(self.cache_file_path + "~", "w") as f:
f.write(auth.model_dump_json())
# rename is atomic
os.replace(self.cache_file_path + "~", self.cache_file_path)
self._schedule_auth_refresh(auth.refresh_expires_in)
return auth


class Authenticator(AuthBase):
Expand All @@ -179,23 +218,17 @@
self._url_prefix = config.url
# HACK: start with a slash to override the /v2/ on the api url
self._realm_url = urljoin(str(config.url), "/auth/realms/simai")
self._token_url = f"{self._realm_url}/protocol/openid-connect/token"
self._organization_name = config.organization
self._refresh_timer = None
if config.credentials:
auth = _AuthTokens.from_request_direct_grant(
self._session, self._token_url, config.credentials
)
else:
device_auth_url = f"{self._realm_url}/protocol/openid-connect/auth/device"
auth_hash = None if config.disable_cache else config._auth_hash()
auth = _get_cached_or_request_device_auth(
self._session, device_auth_url, self._token_url, auth_hash
)
self._authentication = auth
self._schedule_auth_refresh()
auth_hash = config._auth_hash()
self.tokens_retriever = _AuthTokensRetriever(
config.credentials, session, auth_hash, self._realm_url
)
self.tokens_retriever.get_tokens(
force_refresh=True
) # start fetching/refreshing auth tokens

def __call__(self, request: requests.Request) -> requests.Request:
def __call__(self, request: PreparedRequest) -> PreparedRequest:
"""Call to prepare the requests.

Args:
Expand All @@ -204,31 +237,18 @@
Returns:
Request with the authentication.
"""
assert request.url
request_host = request.url.split("://", 1)[-1] # ignore protocol part
if self._enabled and request_host.startswith(self._url_prefix.host):
is_token_expired = self._authentication.is_token_expired()
if (
self._enabled
and request_host.startswith(self._url_prefix.host)
and self._realm_url not in request.url
):
# So the token doesn't expire during requests that upload a large amount of data
is_request_multipart_data = "multipart/form_data" in request.headers.get(
"Content-Type", ""
)
is_auth_request = self._realm_url in request.url
if not is_auth_request and is_token_expired or is_request_multipart_data:
self._refresh_auth()
request.headers["Authorization"] = f"Bearer {self._authentication.access_token}"
auth = self.tokens_retriever.get_tokens(force_refresh=is_request_multipart_data)
request.headers["Authorization"] = f"Bearer {auth.access_token}"
request.headers["X-Org"] = self._organization_name
return request

def _refresh_auth(self):
"""Refresh the authentication."""
self._authentication = self._authentication.refresh(self._session, self._token_url)
self._schedule_auth_refresh()

def _schedule_auth_refresh(self):
"""Schedule authentication refresh to avoids refresh token expiring if the client is idle for a long time."""
if self._refresh_timer:
self._refresh_timer.cancel()
self._refresh_timer = threading.Timer(
self._authentication.refresh_expires_in - 60, self._refresh_auth
)
self._refresh_timer.daemon = True
self._refresh_timer.start()
6 changes: 1 addition & 5 deletions src/ansys/simai/core/utils/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
ValidationInfo,
field_validator,
model_validator,
validator,
)
from pydantic_core import PydanticCustomError

Expand Down Expand Up @@ -110,11 +109,8 @@ class ClientConfig(BaseModel, extra="allow"):
no_sse_connection: bool = Field(
default=False, description="Don't receive live updates from the SimAI API."
)
disable_cache: bool = Field(
default=False, description="Don't use cached authentication tokens."
)

@validator("url", pre=True)
@field_validator("url", mode="before")
def clean_url(cls, url):
if isinstance(url, bytes):
url = url.decode()
Expand Down
Loading
Loading