Skip to content

Commit

Permalink
Merge pull request #412 from AzureAD/release-1.15.0
Browse files Browse the repository at this point in the history
Release 1.15.0
  • Loading branch information
rayluo authored Sep 30, 2021
2 parents be55e2b + 7d1c16d commit 8a4cdea
Show file tree
Hide file tree
Showing 8 changed files with 215 additions and 25 deletions.
76 changes: 69 additions & 7 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import requests

from .oauth2cli import Client, JwtAssertionCreator
from .oauth2cli.oidc import decode_part
from .authority import Authority
from .mex import send_request as mex_send_request
from .wstrust_request import send_request as wst_send_request
Expand All @@ -25,7 +26,7 @@


# The __init__.py will import this. Not the other way around.
__version__ = "1.14.0"
__version__ = "1.15.0"

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -111,6 +112,36 @@ def _preferred_browser():
return None


class _ClientWithCcsRoutingInfo(Client):

def initiate_auth_code_flow(self, **kwargs):
if kwargs.get("login_hint"): # eSTS could have utilized this as-is, but nope
kwargs["X-AnchorMailbox"] = "UPN:%s" % kwargs["login_hint"]
return super(_ClientWithCcsRoutingInfo, self).initiate_auth_code_flow(
client_info=1, # To be used as CSS Routing info
**kwargs)

def obtain_token_by_auth_code_flow(
self, auth_code_flow, auth_response, **kwargs):
# Note: the obtain_token_by_browser() is also covered by this
assert isinstance(auth_code_flow, dict) and isinstance(auth_response, dict)
headers = kwargs.pop("headers", {})
client_info = json.loads(
decode_part(auth_response["client_info"])
) if auth_response.get("client_info") else {}
if "uid" in client_info and "utid" in client_info:
# Note: The value of X-AnchorMailbox is also case-insensitive
headers["X-AnchorMailbox"] = "Oid:{uid}@{utid}".format(**client_info)
return super(_ClientWithCcsRoutingInfo, self).obtain_token_by_auth_code_flow(
auth_code_flow, auth_response, headers=headers, **kwargs)

def obtain_token_by_username_password(self, username, password, **kwargs):
headers = kwargs.pop("headers", {})
headers["X-AnchorMailbox"] = "upn:{}".format(username)
return super(_ClientWithCcsRoutingInfo, self).obtain_token_by_username_password(
username, password, headers=headers, **kwargs)


class ClientApplication(object):

ACQUIRE_TOKEN_SILENT_ID = "84"
Expand Down Expand Up @@ -174,7 +205,7 @@ def __init__(
you may try use only the leaf cert (in PEM/str format) instead.
*Added in version 1.13.0*:
It can also be a completly pre-signed assertion that you've assembled yourself.
It can also be a completely pre-signed assertion that you've assembled yourself.
Simply pass a container containing only the key "client_assertion", like this::
{
Expand Down Expand Up @@ -481,7 +512,7 @@ def _build_client(self, client_credential, authority, skip_regional_client=False
authority.device_authorization_endpoint or
urljoin(authority.token_endpoint, "devicecode"),
}
central_client = Client(
central_client = _ClientWithCcsRoutingInfo(
central_configuration,
self.client_id,
http_client=self.http_client,
Expand All @@ -506,7 +537,7 @@ def _build_client(self, client_credential, authority, skip_regional_client=False
regional_authority.device_authorization_endpoint or
urljoin(regional_authority.token_endpoint, "devicecode"),
}
regional_client = Client(
regional_client = _ClientWithCcsRoutingInfo(
regional_configuration,
self.client_id,
http_client=self.http_client,
Expand All @@ -529,6 +560,7 @@ def initiate_auth_code_flow(
login_hint=None, # type: Optional[str]
domain_hint=None, # type: Optional[str]
claims_challenge=None,
max_age=None,
):
"""Initiate an auth code flow.
Expand Down Expand Up @@ -559,6 +591,17 @@ def initiate_auth_code_flow(
`here <https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-auth-code-flow#request-an-authorization-code>`_ and
`here <https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-oapx/86fb452d-e34a-494e-ac61-e526e263b6d8>`_.
:param int max_age:
OPTIONAL. Maximum Authentication Age.
Specifies the allowable elapsed time in seconds
since the last time the End-User was actively authenticated.
If the elapsed time is greater than this value,
Microsoft identity platform will actively re-authenticate the End-User.
MSAL Python will also automatically validate the auth_time in ID token.
New in version 1.15.
:return:
The auth code flow. It is a dict in this form::
Expand All @@ -577,7 +620,7 @@ def initiate_auth_code_flow(
3. and then relay this dict and subsequent auth response to
:func:`~acquire_token_by_auth_code_flow()`.
"""
client = Client(
client = _ClientWithCcsRoutingInfo(
{"authorization_endpoint": self.authority.authorization_endpoint},
self.client_id,
http_client=self.http_client)
Expand All @@ -588,6 +631,7 @@ def initiate_auth_code_flow(
domain_hint=domain_hint,
claims=_merge_claims_challenge_and_capabilities(
self._client_capabilities, claims_challenge),
max_age=max_age,
)
flow["claims_challenge"] = claims_challenge
return flow
Expand Down Expand Up @@ -654,7 +698,7 @@ def get_authorization_request_url(
self.http_client
) if authority else self.authority

client = Client(
client = _ClientWithCcsRoutingInfo(
{"authorization_endpoint": the_authority.authorization_endpoint},
self.client_id,
http_client=self.http_client)
Expand Down Expand Up @@ -1178,6 +1222,10 @@ def _acquire_token_silent_by_finding_specific_refresh_token(
key=lambda e: int(e.get("last_modification_time", "0")),
reverse=True):
logger.debug("Cache attempts an RT")
headers = telemetry_context.generate_headers()
if "home_account_id" in query: # Then use it as CCS Routing info
headers["X-AnchorMailbox"] = "Oid:{}".format( # case-insensitive value
query["home_account_id"].replace(".", "@"))
response = client.obtain_token_by_refresh_token(
entry, rt_getter=lambda token_item: token_item["secret"],
on_removing_rt=lambda rt_item: None, # Disable RT removal,
Expand All @@ -1189,7 +1237,7 @@ def _acquire_token_silent_by_finding_specific_refresh_token(
skip_account_creation=True, # To honor a concurrent remove_account()
)),
scope=scopes,
headers=telemetry_context.generate_headers(),
headers=headers,
data=dict(
kwargs.pop("data", {}),
claims=_merge_claims_challenge_and_capabilities(
Expand Down Expand Up @@ -1370,6 +1418,7 @@ def acquire_token_interactive(
timeout=None,
port=None,
extra_scopes_to_consent=None,
max_age=None,
**kwargs):
"""Acquire token interactively i.e. via a local browser.
Expand Down Expand Up @@ -1415,6 +1464,17 @@ def acquire_token_interactive(
in the same interaction, but for which you won't get back a
token for in this particular operation.
:param int max_age:
OPTIONAL. Maximum Authentication Age.
Specifies the allowable elapsed time in seconds
since the last time the End-User was actively authenticated.
If the elapsed time is greater than this value,
Microsoft identity platform will actively re-authenticate the End-User.
MSAL Python will also automatically validate the auth_time in ID token.
New in version 1.15.
:return:
- A dict containing no "error" key,
and typically contains an "access_token" key.
Expand All @@ -1433,6 +1493,7 @@ def acquire_token_interactive(
port=port or 0),
prompt=prompt,
login_hint=login_hint,
max_age=max_age,
timeout=timeout,
auth_params={
"claims": claims,
Expand Down Expand Up @@ -1581,6 +1642,7 @@ def acquire_token_on_behalf_of(self, user_assertion, scopes, claims_challenge=No
claims=_merge_claims_challenge_and_capabilities(
self._client_capabilities, claims_challenge)),
headers=telemetry_context.generate_headers(),
# TBD: Expose a login_hint (or ccs_routing_hint) param for web app
**kwargs))
telemetry_context.update_telemetry(response)
return response
46 changes: 39 additions & 7 deletions msal/oauth2cli/authcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import logging
import socket
from string import Template
import threading
import time

try: # Python 3
from http.server import HTTPServer, BaseHTTPRequestHandler
Expand Down Expand Up @@ -143,17 +145,14 @@ def __init__(self, port=None):
# TODO: But, it would treat "localhost" or "" as IPv4.
# If pressed, we might just expose a family parameter to caller.
self._server = Server((address, port or 0), _AuthCodeHandler)
self._closing = False

def get_port(self):
"""The port this server actually listening to"""
# https://docs.python.org/2.7/library/socketserver.html#SocketServer.BaseServer.server_address
return self._server.server_address[1]

def get_auth_response(self, auth_uri=None, timeout=None, state=None,
welcome_template=None, success_template=None, error_template=None,
auth_uri_callback=None,
browser_name=None,
):
def get_auth_response(self, timeout=None, **kwargs):
"""Wait and return the auth response. Raise RuntimeError when timeout.
:param str auth_uri:
Expand Down Expand Up @@ -192,6 +191,37 @@ def get_auth_response(self, auth_uri=None, timeout=None, state=None,
and https://openid.net/specs/openid-connect-core-1_0.html#AuthResponse
Returns None when the state was mismatched, or when timeout occurred.
"""
# Historically, the _get_auth_response() uses HTTPServer.handle_request(),
# because its handle-and-retry logic is conceptually as easy as a while loop.
# Also, handle_request() honors server.timeout setting, and CTRL+C simply works.
# All those are true when running on Linux.
#
# However, the behaviors on Windows turns out to be different.
# A socket server waiting for request would freeze the current thread.
# Neither timeout nor CTRL+C would work. End user would have to do CTRL+BREAK.
# https://stackoverflow.com/questions/1364173/stopping-python-using-ctrlc
#
# The solution would need to somehow put the http server into its own thread.
# This could be done by the pattern of ``http.server.test()`` which internally
# use ``ThreadingHTTPServer.serve_forever()`` (only available in Python 3.7).
# Or create our own thread to wrap the HTTPServer.handle_request() inside.
result = {} # A mutable object to be filled with thread's return value
t = threading.Thread(
target=self._get_auth_response, args=(result,), kwargs=kwargs)
t.daemon = True # So that it won't prevent the main thread from exiting
t.start()
begin = time.time()
while (time.time() - begin < timeout) if timeout else True:
time.sleep(1) # Short detection interval to make happy path responsive
if not t.is_alive(): # Then the thread has finished its job and exited
break
return result or None

def _get_auth_response(self, result, auth_uri=None, timeout=None, state=None,
welcome_template=None, success_template=None, error_template=None,
auth_uri_callback=None,
browser_name=None,
):
welcome_uri = "http://localhost:{p}".format(p=self.get_port())
abort_uri = "{loc}?error=abort".format(loc=welcome_uri)
logger.debug("Abort by visit %s", abort_uri)
Expand Down Expand Up @@ -229,7 +259,8 @@ def get_auth_response(self, auth_uri=None, timeout=None, state=None,

self._server.timeout = timeout # Otherwise its handle_timeout() won't work
self._server.auth_response = {} # Shared with _AuthCodeHandler
while True:
while not self._closing: # Otherwise, the handle_request() attempt
# would yield noisy ValueError trace
# Derived from
# https://docs.python.org/2/library/basehttpserver.html#more-examples
self._server.handle_request()
Expand All @@ -238,10 +269,11 @@ def get_auth_response(self, auth_uri=None, timeout=None, state=None,
logger.debug("State mismatch. Ignoring this noise.")
else:
break
return self._server.auth_response
result.update(self._server.auth_response) # Return via writable result param

def close(self):
"""Either call this eventually; or use the entire class as context manager"""
self._closing = True
self._server.server_close()

def __enter__(self):
Expand Down
2 changes: 1 addition & 1 deletion msal/oauth2cli/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749
_data["client_assertion"] = encoder(
self.client_assertion() # Do lazy on-the-fly computation
if callable(self.client_assertion) else self.client_assertion
) # The type is bytes, which is preferrable. See also:
) # The type is bytes, which is preferable. See also:
# https://github.com/psf/requests/issues/4503#issuecomment-455001070

_data.update(self.default_body) # It may contain authen parameters
Expand Down
30 changes: 26 additions & 4 deletions msal/oauth2cli/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def decode_id_token(id_token, client_id=None, issuer=None, nonce=None, now=None)
"""
decoded = json.loads(decode_part(id_token.split('.')[1]))
err = None # https://openid.net/specs/openid-connect-core-1_0.html#IDTokenValidation
_now = now or time.time()
_now = int(now or time.time())
skew = 120 # 2 minutes
if _now + skew < decoded.get("nbf", _now - 1): # nbf is optional per JWT specs
# This is not an ID token validation, but a JWT validation
Expand All @@ -67,14 +67,14 @@ def decode_id_token(id_token, client_id=None, issuer=None, nonce=None, now=None)
# the Client and the Token Endpoint (which it is during _obtain_token()),
# the TLS server validation MAY be used to validate the issuer
# in place of checking the token signature.
if _now > decoded["exp"]:
if _now - skew > decoded["exp"]:
err = "9. The current time MUST be before the time represented by the exp Claim."
if nonce and nonce != decoded.get("nonce"):
err = ("11. Nonce must be the same value "
"as the one that was sent in the Authentication Request.")
if err:
raise RuntimeError("%s The id_token was: %s" % (
err, json.dumps(decoded, indent=2)))
raise RuntimeError("%s Current epoch = %s. The id_token was: %s" % (
err, _now, json.dumps(decoded, indent=2)))
return decoded


Expand Down Expand Up @@ -187,6 +187,8 @@ def initiate_auth_code_flow(
flow = super(Client, self).initiate_auth_code_flow(
scope=_scope, nonce=_nonce_hash(nonce), **kwargs)
flow["nonce"] = nonce
if kwargs.get("max_age") is not None:
flow["max_age"] = kwargs["max_age"]
return flow

def obtain_token_by_auth_code_flow(self, auth_code_flow, auth_response, **kwargs):
Expand All @@ -208,6 +210,26 @@ def obtain_token_by_auth_code_flow(self, auth_code_flow, auth_response, **kwargs
raise RuntimeError(
'The nonce in id token ("%s") should match our nonce ("%s")' %
(nonce_in_id_token, expected_hash))

if auth_code_flow.get("max_age") is not None:
auth_time = result.get("id_token_claims", {}).get("auth_time")
if not auth_time:
raise RuntimeError(
"13. max_age was requested, ID token should contain auth_time")
now = int(time.time())
skew = 120 # 2 minutes. Hardcoded, for now
if now - skew > auth_time + auth_code_flow["max_age"]:
raise RuntimeError(
"13. auth_time ({auth_time}) was requested, "
"by using max_age ({max_age}) parameter, "
"and now ({now}) too much time has elasped "
"since last end-user authentication. "
"The ID token was: {id_token}".format(
auth_time=auth_time,
max_age=auth_code_flow["max_age"],
now=now,
id_token=json.dumps(result["id_token_claims"], indent=2),
))
return result

def obtain_token_by_browser(
Expand Down
7 changes: 4 additions & 3 deletions msal/throttled_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,17 @@ def __init__(self, http_client, http_cache):
# acquire_token_silent(..., force_refresh=True) pattern.
str(kwargs.get("params")) + str(kwargs.get("data"))),
),
expires_in=lambda result=None, data=None, **ignored:
expires_in=lambda result=None, kwargs=None, **ignored:
60
if result.status_code == 400
# Here we choose to cache exact HTTP 400 errors only (rather than 4xx)
# because they are the ones defined in OAuth2
# (https://datatracker.ietf.org/doc/html/rfc6749#section-5.2)
# Other 4xx errors might have different requirements e.g.
# "407 Proxy auth required" would need a key including http headers.
and not( # Exclude Device Flow cause its retry is expected and regulated
isinstance(data, dict) and data.get("grant_type") == DEVICE_AUTH_GRANT
and not( # Exclude Device Flow whose retry is expected and regulated
isinstance(kwargs.get("data"), dict)
and kwargs["data"].get("grant_type") == DEVICE_AUTH_GRANT
)
and "retry-after" not in set( # Leave it to the Retry-After decorator
h.lower() for h in getattr(result, "headers", {}).keys())
Expand Down
Loading

0 comments on commit 8a4cdea

Please sign in to comment.