Skip to content

Commit

Permalink
Add more dunder methods
Browse files Browse the repository at this point in the history
Signed-off-by: Paul Van Eck <[email protected]>
  • Loading branch information
pvaneck committed Jul 10, 2024
1 parent 1c3d671 commit f7b332f
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 2 deletions.
24 changes: 23 additions & 1 deletion sdk/core/azure-core/azure/core/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Licensed under the MIT License. See LICENSE.txt in the project root for
# license information.
# -------------------------------------------------------------------------
from dataclasses import dataclass
from dataclasses import dataclass, fields
from typing import Any, NamedTuple, Optional
from typing_extensions import Protocol, runtime_checkable

Expand All @@ -30,6 +30,28 @@ def __iter__(self) -> Any:
yield self.token
yield self.expires_on

def __getitem__(self, index: Any) -> Any:
"""Access the value at the given index.
:param Any index: The index to access.
:return: The value at the given index.
:rtype: Any
"""
token_fields = fields(self)
field = token_fields[index]
if isinstance(field, tuple):
# Handle slices
return tuple(getattr(self, field[i].name) for i in range(len(field)))
return getattr(self, field.name)

def __len__(self) -> int:
"""Return the number of fields in the AccessToken.
:return: The number of fields in the AccessToken.
:rtype: int
"""
return len(fields(self))


@runtime_checkable
class TokenCredential(Protocol):
Expand Down
21 changes: 20 additions & 1 deletion sdk/core/azure-core/tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,8 +642,27 @@ def verify_authorization_header(request):

def test_access_token_unpack_backwards_compat():
"""Test that AccessToken can be unpacked as a two-value tuple for backwards compatibility."""
_ = AccessToken("token", 42)
token = AccessToken("token", 42)
assert token.token == "token"
assert token.expires_on == 42
assert token.refresh_on is None

_, _ = AccessToken("token", 42)
_, _ = AccessToken(token="token", expires_on=42)
_, _ = AccessToken("token", 42, 42)
_, _ = AccessToken(token="token", expires_on=42, refresh_on=42)


def test_access_token_subscriptable():
"""Test that AccessToken can be indexed by position. This is verify backwards-compatibility."""
token = AccessToken("token", 42)
assert token[0] == "token"
assert token[1] == 42
assert token[2] is None
token = AccessToken("token", 42, 100)
assert token[:2] == ("token", 42)
assert token[:3] == ("token", 42, 100)
assert token[:-1] == ("token", 42)
assert len(token) == 3
with pytest.raises(IndexError):
_ = token[3]

0 comments on commit f7b332f

Please sign in to comment.