Skip to content

Commit

Permalink
[Core] Convert AccessToken to dataclass
Browse files Browse the repository at this point in the history
This allows adding a new field to AccessToken without breaking
unpacking or index-based access of the previous credential.

Signed-off-by: Paul Van Eck <[email protected]>
  • Loading branch information
pvaneck committed Jul 11, 2024
1 parent f812aaa commit 7cba5d5
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 14 deletions.
25 changes: 21 additions & 4 deletions sdk/core/azure-core/azure/core/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,38 @@
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
from dataclasses import dataclass, fields
from typing import Any, NamedTuple, Optional
from typing_extensions import Protocol, runtime_checkable


class AccessToken(NamedTuple):
@dataclass(frozen=True)
class AccessToken:
"""Represents an OAuth access token."""

token: str
"""The token string."""
expires_on: int
"""The token's expiration time in Unix time."""
refresh_on: Optional[int] = None
"""When the token should be refreshed in Unix time."""

def __iter__(self) -> Any:
"""Return an iterator that will yield non-None values.
AccessToken.token.__doc__ = """The token string."""
AccessToken.expires_on.__doc__ = """The token's expiration time in Unix time."""
AccessToken.refresh_on.__doc__ = """When the token should be refreshed in Unix time."""
This is backwards compatible with code unpacking the token and expires_on values of AccessToken.
:return: Iterator containing the token and expires_on values.
:rtype: Iterator
"""
# Note: `fields` returns a tuple of fields in the order they are defined in the class.
return (getattr(self, field.name) for field in fields(self) if getattr(self, field.name) is not None)

def __getitem__(self, index: Any) -> Any:
return tuple(self)[index]

def __len__(self) -> int:
return len(tuple(self.__iter__()))


@runtime_checkable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,8 @@ def _update_headers(headers: MutableMapping[str, str], token: str) -> None:
@property
def _need_new_token(self) -> bool:
now = time.time()
return (
not self._token
or (self._token.refresh_on is not None and self._token.refresh_on <= now)
or self._token.expires_on - now < 300
)
refresh_on = getattr(self._token, "refresh_on", None)
return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300


class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[HTTPRequestType, HTTPResponseType]):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,5 @@ def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None:

def _need_new_token(self) -> bool:
now = time.time()
return (
not self._token
or (self._token.refresh_on is not None and self._token.refresh_on <= now)
or self._token.expires_on - now < 300
)
refresh_on = getattr(self._token, "refresh_on", None)
return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300
49 changes: 49 additions & 0 deletions sdk/core/azure-core/tests/async_tests/test_authentication_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# license information.
# -------------------------------------------------------------------------
import asyncio
from collections import namedtuple
import sys
import time
from unittest.mock import Mock, patch
Expand Down Expand Up @@ -470,3 +471,51 @@ def test_async_token_credential_sync():
# Ensure trio isn't in sys.modules (i.e. imported).
sys.modules.pop("trio", None)
AsyncBearerTokenCredentialPolicy(Mock(), "scope")


def test_need_new_token():
expected_scope = "scope"
now = int(time.time())

policy = AsyncBearerTokenCredentialPolicy(Mock(), expected_scope)

# Token is expired.
policy._token = AccessToken("", now - 1200)
assert policy._need_new_token()

# Token is about to expire within 300 seconds.
policy._token = AccessToken("", now + 299)
assert policy._need_new_token()

# Token still has more than 300 seconds to live.
policy._token = AccessToken("", now + 305)
assert not policy._need_new_token()

# Token has both expires_on and refresh_on set well into the future.
policy._token = AccessToken("", now + 1200, now + 1200)
assert not policy._need_new_token()

# Token is not close to expiring, but refresh_on is in the past.
policy._token = AccessToken("", now + 1200, now - 1)
assert policy._need_new_token()

policy._token = None
assert policy._need_new_token()


def test_need_new_token_with_external_defined_token_class():
"""Test the case where some custom credential get_token call returns a custom token object."""
FooAccessToken = namedtuple("FooAccessToken", ["token", "expires_on"])

expected_scope = "scope"
now = int(time.time())

policy = AsyncBearerTokenCredentialPolicy(Mock(), expected_scope)

# Token is expired.
policy._token = FooAccessToken("", now - 1200)
assert policy._need_new_token()

# Token is about to expire within 300 seconds.
policy._token = FooAccessToken("", now + 299)
assert policy._need_new_token()
66 changes: 66 additions & 0 deletions sdk/core/azure-core/tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# license information.
# -------------------------------------------------------------------------
import time
from collections import namedtuple
from itertools import product
from requests import Response
import azure.core
Expand Down Expand Up @@ -331,6 +332,27 @@ def test_need_new_token():
policy._token = AccessToken("", now + 1200, now - 1)
assert policy._need_new_token

policy._token = None
assert policy._need_new_token


def test_need_new_token_with_external_defined_token_class():
"""Test the case where some custom credential get_token call returns a custom token object."""
FooAccessToken = namedtuple("FooAccessToken", ["token", "expires_on"])

expected_scope = "scope"
now = int(time.time())

policy = BearerTokenCredentialPolicy(Mock(), expected_scope)

# Token is expired.
policy._token = FooAccessToken("", now - 1200)
assert policy._need_new_token

# Token is about to expire within 300 seconds.
policy._token = FooAccessToken("", now + 299)
assert policy._need_new_token


@pytest.mark.parametrize("http_request", HTTP_REQUESTS)
def test_azure_key_credential_policy(http_request):
Expand Down Expand Up @@ -638,3 +660,47 @@ def verify_authorization_header(request):
pipeline = Pipeline(transport=transport, policies=[credential_policy])

pipeline.run(http_request("GET", "https://test_key_credential"))


def test_access_token_unpack():
"""Test various unpacking of AccessToken."""
token = AccessToken("token", 42)
assert token.token == "token"
assert token.expires_on == 42
assert token.refresh_on is None

token, expires_on = AccessToken("token", 42)
assert token == "token"
assert expires_on == 42

token, expires_on = AccessToken(token="token", expires_on=42)
assert token == "token"
assert expires_on == 42

token, expires_on, refresh_on = AccessToken("token", 42, 21)
assert token == "token"
assert expires_on == 42
assert refresh_on == 21

token, expires_on, refresh_on = AccessToken(token="token", expires_on=42, refresh_on=21)
assert token == "token"
assert expires_on == 42
assert refresh_on == 21


def test_access_token_subscriptable():
"""Test that AccessToken can be indexed by position. This is verify backwards-compatibility."""
token = AccessToken("token", 42)
assert token[0] == "token"
assert token[1] == 42
assert token[-1] == 42
assert len(token) == 2

token = AccessToken("token", 42, 100)
assert token[:1] == ("token",)
assert token[:2] == ("token", 42)
assert token[:3] == ("token", 42, 100)
assert token[:-1] == ("token", 42)
assert len(token) == 3
with pytest.raises(IndexError):
_ = token[3]

0 comments on commit 7cba5d5

Please sign in to comment.