Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] AccessToken subclassing #36464

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 51 additions & 5 deletions sdk/core/azure-core/azure/core/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,65 @@
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
from typing import Any, NamedTuple, Optional
from typing import Any, NamedTuple, Optional, Dict
from typing_extensions import Protocol, runtime_checkable


class AccessToken(NamedTuple):
"""Represents an OAuth access token."""

token: str
"""The token string."""
expires_on: int
refresh_on: Optional[int] = None
"""The token's expiration time in Unix time."""


AccessToken.token.__doc__ = """The token string."""
AccessToken.expires_on.__doc__ = """The token's expiration time in Unix time."""
AccessToken.refresh_on.__doc__ = """When the token should be refreshed in Unix time."""
class ExtendedAccessToken(AccessToken):
"""Represents an OAuth access token with additional properties.

Currently, the only additional property is the token's refresh on time.

In order to maintain backwards compatibility with the base AccessToken, only properties defined
in the base class are used for operations like comparison and iteration.

:param str token: The token string.
:param int expires_on: The token's expiration time in Unix time.
:keyword int refresh_on: The token's refresh on time in Unix time. Optional.
"""

_refresh_on: Optional[int]

def __new__(cls, token: str, expires_on: int, *, refresh_on: Optional[int] = None):
instance = super().__new__(cls, token, expires_on)
instance._refresh_on = refresh_on
return instance

@property
def refresh_on(self) -> Optional[int]:
"""The token's refresh on time in Unix time.

:rtype: Optional[int]
"""
return self._refresh_on


class ExtensionAccessToken(AccessToken):
"""Exploratory class for storing arbitrary token metadata."""

_additional_properties: Dict[str, Any]

def __new__(cls, token: str, expires_on: int, **kwargs: Any):
instance = super().__new__(cls, token, expires_on)
instance._additional_properties = kwargs
return instance

@property
def additional_properties(self) -> Dict[str, Any]:
"""The additional properties or metadata of the token.

:rtype: dict[str, Any]
"""
return self._additional_properties


@runtime_checkable
Expand Down Expand Up @@ -59,6 +103,8 @@ class AzureNamedKey(NamedTuple):
"AzureKeyCredential",
"AzureSasCredential",
"AccessToken",
"ExtensionAccessToken",
"ExtendedAccessToken",
"AzureNamedKeyCredential",
"TokenCredential",
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,8 @@ def _update_headers(headers: MutableMapping[str, str], token: str) -> None:
@property
def _need_new_token(self) -> bool:
now = time.time()
return (
not self._token
or (self._token.refresh_on is not None and self._token.refresh_on <= now)
or self._token.expires_on - now < 300
)
refresh_on = getattr(self._token, "refresh_on", None)
return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300


class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[HTTPRequestType, HTTPResponseType]):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,5 @@ def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None:

def _need_new_token(self) -> bool:
now = time.time()
return (
not self._token
or (self._token.refresh_on is not None and self._token.refresh_on <= now)
or self._token.expires_on - now < 300
)
refresh_on = getattr(self._token, "refresh_on", None)
return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300
79 changes: 76 additions & 3 deletions sdk/core/azure-core/tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,18 @@
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
from collections import namedtuple
import time
from itertools import product
from requests import Response
import azure.core
from azure.core.credentials import AccessToken, AzureKeyCredential, AzureSasCredential, AzureNamedKeyCredential
from azure.core.credentials import (
AccessToken,
AzureKeyCredential,
AzureSasCredential,
AzureNamedKeyCredential,
ExtendedAccessToken,
)
from azure.core.exceptions import ServiceRequestError
from azure.core.pipeline import Pipeline
from azure.core.pipeline.transport import HttpTransport, HttpRequest
Expand Down Expand Up @@ -324,11 +331,32 @@ def test_need_new_token():
assert not policy._need_new_token

# Token has both expires_on and refresh_on set well into the future.
policy._token = AccessToken("", now + 1200, now + 1200)
policy._token = ExtendedAccessToken("", now + 1200, refresh_on=now + 1200)
assert not policy._need_new_token

# Token is not close to expiring, but refresh_on is in the past.
policy._token = AccessToken("", now + 1200, now - 1)
policy._token = ExtendedAccessToken("", now + 1200, refresh_on=now - 1)
assert policy._need_new_token

policy._token = None
assert policy._need_new_token


def test_need_new_token_with_external_defined_token_class():
"""Test the case where some custom credential get_token call returns a custom token object."""
FooAccessToken = namedtuple("FooAccessToken", ["token", "expires_on"])

expected_scope = "scope"
now = int(time.time())

policy = BearerTokenCredentialPolicy(Mock(), expected_scope)

# Token is expired.
policy._token = FooAccessToken("", now - 1200)
assert policy._need_new_token

# Token is about to expire within 300 seconds.
policy._token = FooAccessToken("", now + 299)
assert policy._need_new_token


Expand Down Expand Up @@ -638,3 +666,48 @@ def verify_authorization_header(request):
pipeline = Pipeline(transport=transport, policies=[credential_policy])

pipeline.run(http_request("GET", "https://test_key_credential"))


def test_extended_access_token_unpack():
"""Test various unpacking of AccessToken."""
token = ExtendedAccessToken("token", 42)
assert token.token == "token"
assert token.expires_on == 42
assert token.refresh_on is None

token, expires_on = ExtendedAccessToken(token="token", expires_on=42, refresh_on=21)
assert token == "token"
assert expires_on == 42

token, expires_on = ExtendedAccessToken("token", 42)
assert token == "token"
assert expires_on == 42

token, expires_on = ExtendedAccessToken(token="token", expires_on=42)
assert token == "token"
assert expires_on == 42

token, expires_on = ExtendedAccessToken("token", 42, refresh_on=21)
assert token == "token"
assert expires_on == 42

token, expires_on = ExtendedAccessToken(token="token", expires_on=42, refresh_on=21)
assert token == "token"
assert expires_on == 42


def test_access_token_subscriptable():
"""Test that AccessToken can be indexed by position. This is verify backwards-compatibility."""
token = ExtendedAccessToken("token", 42)
assert token[0] == "token"
assert token[1] == 42
assert token[-1] == 42
assert len(token) == 2

token = ExtendedAccessToken("token", 42, refresh_on=100)
assert token[:1] == ("token",)
assert token[:2] == ("token", 42)
assert token[:-1] == ("token",)
assert len(token) == 2
with pytest.raises(IndexError):
_ = token[2]