Skip to content

Commit

Permalink
CAE for MIv1
Browse files Browse the repository at this point in the history
CAE team and MSI team are working on turning on CAE by default for MSI
v1. So what that means is, App developers will start seeing CAE even
without setting the capability - "CP1".
  • Loading branch information
rayluo committed Aug 7, 2024
1 parent 95e1bb0 commit f0833ee
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 15 deletions.
8 changes: 5 additions & 3 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,9 +411,11 @@ def __init__(
(STS) what this client is capable for,
so STS can decide to turn on certain features.
For example, if client is capable to handle *claims challenge*,
STS can then issue CAE access tokens to resources
knowing when the resource emits *claims challenge*
the client will be capable to handle.
STS may issue
`Continuous Access Evaluation (CAE) <https://learn.microsoft.com/en-us/entra/identity/conditional-access/concept-continuous-access-evaluation>`_
access tokens to resources,
knowing that when the resource emits *claims challenge*
the client will be capable to handle those challenges.
Implementation details:
Client capability is implemented using "claims" parameter on the wire,
Expand Down
33 changes: 28 additions & 5 deletions msal/managed_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import time
from urllib.parse import urlparse # Python 3+
from collections import UserDict # Python 3+
from typing import Union # Needed in Python 3.7 & 3.8
from typing import Optional, Union # Needed in Python 3.7 & 3.8
from .token_cache import TokenCache
from .individual_cache import _IndividualCache as IndividualCache
from .throttled_http_client import ThrottledHttpClientBase, RetryAfterParser
Expand Down Expand Up @@ -145,6 +145,9 @@ class ManagedIdentityClient(object):
not a token with application permissions for an app.
"""
__instance, _tenant = None, "managed_identity" # Placeholders
_TOKEN_SOURCE = "token_source"
_TOKEN_SOURCE_IDP = "identity_provider"
_TOKEN_SOURCE_CACHE = "cache"

def __init__(
self,
Expand Down Expand Up @@ -237,12 +240,31 @@ def _get_instance(self):
self.__instance = socket.getfqdn() # Moved from class definition to here
return self.__instance

def acquire_token_for_client(self, *, resource): # We may support scope in the future
def acquire_token_for_client(
self,
*,
resource: str, # If/when we support scope, resource will become optional
claims_challenge: Optional[str] = None,
):
"""Acquire token for the managed identity.
The result will be automatically cached.
Subsequent calls will automatically search from cache first.
:param resource: The resource for which the token is acquired.
:param claims_challenge:
Optional.
It is a string of a JSON object
(which contains lists of claims being requested).
The tenant admin may choose to revoke all Managed Identity tokens,
and then *claims challenge* will be thrown by the target resource,
as a `claims_challenge` directive in the `www-authenticate` header,
even if the app developer did not opt in for client capability "CP1".
Upon receiving the claims_challenge, MSAL will skip a token cache read,
and will attempt to acquire a new token.
.. note::
Known issue: When an Azure VM has only one user-assigned managed identity,
Expand All @@ -255,8 +277,8 @@ def acquire_token_for_client(self, *, resource): # We may support scope in the
access_token_from_cache = None
client_id_in_cache = self._managed_identity.get(
ManagedIdentity.ID, "SYSTEM_ASSIGNED_MANAGED_IDENTITY")
if True: # Does not offer an "if not force_refresh" option, because
# there would be built-in token cache in the service side anyway
now = time.time()
if not claims_challenge: # Then attempt token cache search
matches = self._token_cache.find(
self._token_cache.CredentialType.ACCESS_TOKEN,
target=[resource],
Expand All @@ -267,7 +289,6 @@ def acquire_token_for_client(self, *, resource): # We may support scope in the
home_account_id=None,
),
)
now = time.time()
for entry in matches:
expires_in = int(entry["expires_on"]) - now
if expires_in < 5*60: # Then consider it expired
Expand All @@ -277,6 +298,7 @@ def acquire_token_for_client(self, *, resource): # We may support scope in the
"access_token": entry["secret"],
"token_type": entry.get("token_type", "Bearer"),
"expires_in": int(expires_in), # OAuth2 specs defines it as int
self._TOKEN_SOURCE: self._TOKEN_SOURCE_CACHE,
}
if "refresh_on" in entry:
access_token_from_cache["refresh_on"] = int(entry["refresh_on"])
Expand All @@ -300,6 +322,7 @@ def acquire_token_for_client(self, *, resource): # We may support scope in the
))
if "refresh_in" in result:
result["refresh_on"] = int(now + result["refresh_in"])
result[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP
if (result and "error" not in result) or (not access_token_from_cache):
return result
except: # The exact HTTP exception is transportation-layer dependent
Expand Down
15 changes: 8 additions & 7 deletions tests/test_mi.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,20 +82,17 @@ def _test_happy_path(self, app, mocked_http, expires_in, resource="R"):
self.assertTrue(
is_subdict_of(expected_result, result), # We will test refresh_on later
"Should obtain a token response")
self.assertTrue(result["token_source"], "identity_provider")
self.assertEqual(expires_in, result["expires_in"], "Should have expected expires_in")
if expires_in >= 7200:
expected_refresh_on = int(time.time() + expires_in / 2)
self.assertTrue(
expected_refresh_on - 1 <= result["refresh_on"] <= expected_refresh_on + 1,
"Should have a refresh_on time around the middle of the token's life")
self.assertEqual(
result["access_token"],
app.acquire_token_for_client(resource=resource).get("access_token"),
"Should hit the same token from cache")

self.assertCacheStatus(app)

result = app.acquire_token_for_client(resource=resource)
self.assertCacheStatus(app)
self.assertEqual("cache", result["token_source"], "Should hit cache")
self.assertEqual(
call_count, mocked_http.call_count,
"No new call to the mocked http should be made for a cache hit")
Expand All @@ -110,6 +107,9 @@ def _test_happy_path(self, app, mocked_http, expires_in, resource="R"):
expected_refresh_on - 5 < result["refresh_on"] <= expected_refresh_on,
"Should have a refresh_on time around the middle of the token's life")

result = app.acquire_token_for_client(resource=resource, claims_challenge="foo")
self.assertEqual("identity_provider", result["token_source"], "Should miss cache")


class VmTestCase(ClientTestCase):

Expand Down Expand Up @@ -249,7 +249,8 @@ def test_happy_path(self, mocked_stat):
status_code=200,
text='{"access_token": "AT", "expires_in": "%s", "resource": "R"}' % expires_in,
),
]) as mocked_method:
] * 2, # Duplicate a pair of mocks for _test_happy_path()'s CAE check
) as mocked_method:
try:
self._test_happy_path(self.app, mocked_method, expires_in)
mocked_stat.assert_called_with(os.path.join(
Expand Down

0 comments on commit f0833ee

Please sign in to comment.