Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added optional OIDC version parameter to both OIDC versions #11

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 27 additions & 11 deletions aad_token_verify/token_verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,29 @@
from cryptography.hazmat.backends import default_backend
from cryptography.x509 import load_pem_x509_certificate
from jwt import decode, get_unverified_header
from jwt.exceptions import ExpiredSignatureError, InvalidAudienceError, InvalidIssuerError
from jwt.exceptions import (
ExpiredSignatureError,
InvalidAudienceError,
InvalidIssuerError,
)

from .exceptions import AADError, AuthorizationError, TokenParseError


def get_verified_payload(token: str, tenant_id: str = "common", audience_uris: List[str] = None) -> Dict[str, Any]:
def get_verified_payload(
token: str,
tenant_id: str = "common",
audience_uris: List[str] = None,
oidc_version: str = None,
) -> Dict[str, Any]:
"""Gets a verified token payload

Args:
token (str): The token to verify
tenant_id (str, optional): THe tent id of the issuer. Defaults to "common".
tenant_id (str, optional): Tenant id of the issuer. Defaults to "common".
audience_uris (List[str], optional): The audience uris of the token. Defaults to None.
oidc_version: (str, optional): AAD OIDC version implementation used to retrieve the well known public keys.
The only other value accepted is "v2.0".

Raises:
AuthorizationError: If the token is expired
Expand All @@ -26,9 +37,10 @@ def get_verified_payload(token: str, tenant_id: str = "common", audience_uris: L
Returns:
Dict[str, Any]: The verified token paylod
"""
oidc_version = oidc_version if oidc_version == "v2.0" else ""
kid = _get_kid_from_token_header(token)
public_key = _get_public_key(kid, tenant_id)
openid_config = _get_openid_config(tenant_id)
public_key = _get_public_key(kid, tenant_id, oidc_version)
openid_config = _get_openid_config(tenant_id, oidc_version)
try:
payload = decode(
token,
Expand Down Expand Up @@ -75,7 +87,7 @@ def _get_kid_from_token_header(token: str) -> str:
return unverified_token_header.get("kid")


def _get_public_key(kid: str, tenant_id: str):
def _get_public_key(kid: str, tenant_id: str, oidc_version: str):
"""Retrieves the tenant public key using a token's KID

Args:
Expand All @@ -88,7 +100,7 @@ def _get_public_key(kid: str, tenant_id: str):
try:
x5c: List[str] = []
# Iterate JWK keys and extract matching x5c chain
for key in _get_jwk_keys(tenant_id):
for key in _get_jwk_keys(tenant_id, oidc_version):
if key["kid"] == kid:
x5c = key["x5c"]

Expand All @@ -105,11 +117,12 @@ def _get_public_key(kid: str, tenant_id: str):


@cached(cache=TTLCache(maxsize=16, ttl=3600))
def _get_jwk_keys(tenant_id: str) -> List[Dict]:
def _get_jwk_keys(tenant_id: str, oidc_version: str) -> List[Dict]:
"""Retrieves the JWK keys for a specified issuer

Args:
tenant_id (str): The tenant id of the issuer
oidc_version(str): OIDC version for the well known endpoint

Raises:
AADError: If the jwk_uri is not in the OpenID Config
Expand All @@ -118,7 +131,7 @@ def _get_jwk_keys(tenant_id: str) -> List[Dict]:
Returns:
List[Dict]: List of the jwk keys
"""
jwks_uri = _get_openid_config(tenant_id).get("jwks_uri")
jwks_uri = _get_openid_config(tenant_id, oidc_version).get("jwks_uri")
if not jwks_uri:
raise AADError("jwks_uri not in OpenID Config")

Expand All @@ -133,15 +146,18 @@ def _get_jwk_keys(tenant_id: str) -> List[Dict]:


@cached(cache=TTLCache(maxsize=16, ttl=3600))
def _get_openid_config(tenant_id: str) -> Dict[str, Any]:
def _get_openid_config(tenant_id: str, oidc_version: str) -> Dict[str, Any]:
"""Retrieves the OpenID config for a specified issuer

Args:
tenant_id (str): The tenant id of the issuer
oidc_version(str): OIDC version for the well known endpoint

Returns:
Dict[str, Any]: The OpenID config
"""
oidc_response = requests.get(f"https://login.microsoftonline.com/{tenant_id}/.well-known/openid-configuration")
oidc_response = requests.get(
f"https://login.microsoftonline.com/{tenant_id}/{oidc_version}/.well-known/openid-configuration"
)
oidc_response.raise_for_status()
return oidc_response.json()
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@

setup(
name='aad-token-verify',
version='0.2.1',
version='0.3.0',
description='A python utility library to verify an Azure Active Directory OAuth token',
long_description=long_description,
long_description_content_type='text/markdown',
url='https://github.com/GeneralMills/azure-ad-token-verify',
author=['Daniel Thompson'],
author_email='daniel.thompson1@genmills.com',
author_email='daniel.thompson2@genmills.com ',
license='MIT',
classifiers=[
'Development Status :: 4 - Beta',
Expand All @@ -32,7 +32,7 @@
install_requires=[
'requests>=2.25.1,<3',
'PyJWT>=2.1.0,<3',
'cryptography>=42.0.0,<43',
'cryptography>=41.0.1,<42',
'cachetools>=5.3.1,<6'
],
keywords='azure ad token oauth verify jwt',
Expand Down