Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Update AccessToken #36406

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions sdk/core/azure-core/azure/core/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,38 @@
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
from dataclasses import dataclass, fields
from typing import Any, NamedTuple, Optional
from typing_extensions import Protocol, runtime_checkable


class AccessToken(NamedTuple):
@dataclass(frozen=True)
class AccessToken:
"""Represents an OAuth access token."""

token: str
"""The token string."""
expires_on: int
"""The token's expiration time in Unix time."""
refresh_on: Optional[int] = None
"""When the token should be refreshed in Unix time."""

def __iter__(self) -> Any:
"""Return an iterator that will yield non-None values.
AccessToken.token.__doc__ = """The token string."""
AccessToken.expires_on.__doc__ = """The token's expiration time in Unix time."""
AccessToken.refresh_on.__doc__ = """When the token should be refreshed in Unix time."""
This is backwards compatible with code unpacking the token and expires_on values of AccessToken.
:return: Iterator containing the token and expires_on values.
:rtype: Iterator
"""
# Note: `fields` returns a tuple of fields in the order they are defined in the class.
return (getattr(self, field.name) for field in fields(self) if getattr(self, field.name) is not None)

def __getitem__(self, index: Any) -> Any:
return tuple(self)[index]

def __len__(self) -> int:
return len(tuple(self.__iter__()))


@runtime_checkable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,8 @@ def _update_headers(headers: MutableMapping[str, str], token: str) -> None:
@property
def _need_new_token(self) -> bool:
now = time.time()
return (
not self._token
or (self._token.refresh_on is not None and self._token.refresh_on <= now)
or self._token.expires_on - now < 300
)
refresh_on = getattr(self._token, "refresh_on", None)
return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300


class BearerTokenCredentialPolicy(_BearerTokenCredentialPolicyBase, HTTPPolicy[HTTPRequestType, HTTPResponseType]):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,5 @@ def on_exception(self, request: PipelineRequest[HTTPRequestType]) -> None:

def _need_new_token(self) -> bool:
now = time.time()
return (
not self._token
or (self._token.refresh_on is not None and self._token.refresh_on <= now)
or self._token.expires_on - now < 300
)
refresh_on = getattr(self._token, "refresh_on", None)
return not self._token or (refresh_on and refresh_on <= now) or self._token.expires_on - now < 300
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# license information.
# -------------------------------------------------------------------------
import asyncio
from collections import namedtuple
import sys
import time
from unittest.mock import Mock, patch
Expand Down Expand Up @@ -470,3 +471,51 @@ def test_async_token_credential_sync():
# Ensure trio isn't in sys.modules (i.e. imported).
sys.modules.pop("trio", None)
AsyncBearerTokenCredentialPolicy(Mock(), "scope")


def test_need_new_token():
expected_scope = "scope"
now = int(time.time())

policy = AsyncBearerTokenCredentialPolicy(Mock(), expected_scope)

# Token is expired.
policy._token = AccessToken("", now - 1200)
assert policy._need_new_token()

# Token is about to expire within 300 seconds.
policy._token = AccessToken("", now + 299)
assert policy._need_new_token()

# Token still has more than 300 seconds to live.
policy._token = AccessToken("", now + 305)
assert not policy._need_new_token()

# Token has both expires_on and refresh_on set well into the future.
policy._token = AccessToken("", now + 1200, now + 1200)
assert not policy._need_new_token()

# Token is not close to expiring, but refresh_on is in the past.
policy._token = AccessToken("", now + 1200, now - 1)
assert policy._need_new_token()

policy._token = None
assert policy._need_new_token()


def test_need_new_token_with_external_defined_token_class():
"""Test the case where some custom credential get_token call returns a custom token object."""
FooAccessToken = namedtuple("FooAccessToken", ["token", "expires_on"])

expected_scope = "scope"
now = int(time.time())

policy = AsyncBearerTokenCredentialPolicy(Mock(), expected_scope)

# Token is expired.
policy._token = FooAccessToken("", now - 1200)
assert policy._need_new_token()

# Token is about to expire within 300 seconds.
policy._token = FooAccessToken("", now + 299)
assert policy._need_new_token()
66 changes: 66 additions & 0 deletions sdk/core/azure-core/tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# license information.
# -------------------------------------------------------------------------
import time
from collections import namedtuple
from itertools import product
from requests import Response
import azure.core
Expand Down Expand Up @@ -331,6 +332,27 @@ def test_need_new_token():
policy._token = AccessToken("", now + 1200, now - 1)
assert policy._need_new_token

policy._token = None
assert policy._need_new_token


def test_need_new_token_with_external_defined_token_class():
"""Test the case where some custom credential get_token call returns a custom token object."""
FooAccessToken = namedtuple("FooAccessToken", ["token", "expires_on"])

expected_scope = "scope"
now = int(time.time())

policy = BearerTokenCredentialPolicy(Mock(), expected_scope)

# Token is expired.
policy._token = FooAccessToken("", now - 1200)
assert policy._need_new_token

# Token is about to expire within 300 seconds.
policy._token = FooAccessToken("", now + 299)
assert policy._need_new_token


@pytest.mark.parametrize("http_request", HTTP_REQUESTS)
def test_azure_key_credential_policy(http_request):
Expand Down Expand Up @@ -638,3 +660,47 @@ def verify_authorization_header(request):
pipeline = Pipeline(transport=transport, policies=[credential_policy])

pipeline.run(http_request("GET", "https://test_key_credential"))


def test_access_token_unpack():
"""Test various unpacking of AccessToken."""
token = AccessToken("token", 42)
assert token.token == "token"
assert token.expires_on == 42
assert token.refresh_on is None

token, expires_on = AccessToken("token", 42)
assert token == "token"
assert expires_on == 42

token, expires_on = AccessToken(token="token", expires_on=42)
assert token == "token"
assert expires_on == 42

token, expires_on, refresh_on = AccessToken("token", 42, 21)
assert token == "token"
assert expires_on == 42
assert refresh_on == 21

token, expires_on, refresh_on = AccessToken(token="token", expires_on=42, refresh_on=21)
assert token == "token"
assert expires_on == 42
assert refresh_on == 21

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also want to test the case:

token, expires_on = AccessToken(token="token", expires_on=42, refresh_on=21)
assert token == "token"
assert expires_on == 42


def test_access_token_subscriptable():
"""Test that AccessToken can be indexed by position. This is verify backwards-compatibility."""
token = AccessToken("token", 42)
assert token[0] == "token"
assert token[1] == 42
assert token[-1] == 42
assert len(token) == 2

token = AccessToken("token", 42, 100)
assert token[:1] == ("token",)
assert token[:2] == ("token", 42)
assert token[:3] == ("token", 42, 100)
assert token[:-1] == ("token", 42)
assert len(token) == 3
with pytest.raises(IndexError):
_ = token[3]
Loading