-
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* chore upgrade `jsonschema` to 4.20.0 * chore upgrade `textx` to 4.0.1 * chore upgrade `restrictedpython` to 7.0 * chore upgrade `faker` to 20.1.0 * add new package `msal` to 1.26.0 * Add azure auth support * update badssl certificates
- Loading branch information
Showing
17 changed files
with
764 additions
and
210 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
__version__ = '0.0.43a1' | ||
__version__ = '0.0.43a2' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
import json | ||
import msal | ||
import os | ||
import pickle | ||
import subprocess | ||
import time | ||
import logging | ||
|
||
from requests.auth import AuthBase | ||
from requests.models import PreparedRequest | ||
|
||
from cryptography import x509 | ||
from cryptography.hazmat.backends import default_backend | ||
from cryptography.hazmat.primitives import hashes | ||
from cryptography.hazmat.primitives import serialization | ||
from cryptography.hazmat.primitives.serialization import pkcs12 | ||
|
||
from .exceptions import DothttpAzureAuthException | ||
from .parse_models import AzureAuthWrap, AzureAuthType, AzureAuthSP, AzureAuthCertificate | ||
|
||
AZURE_CLI_TOKEN_STORE_PATH = os.path.expanduser('~/.dothttp.azure-cli.pkl') | ||
|
||
AZURE_SP_TOKEN_STORE_PATH = os.path.expanduser( | ||
'~/.dothttp.msal_token_cache.pkl') | ||
|
||
request_logger = logging.getLogger("request") | ||
|
||
|
||
def load_private_key_and_thumbprint(cert_path, password=None): | ||
extension = os.path.splitext(cert_path)[1].lower() | ||
with open(cert_path, "rb") as cert_file: | ||
cert_data = cert_file.read() | ||
|
||
if extension == '.pem': | ||
private_key = serialization.load_pem_private_key( | ||
cert_data, password, default_backend()) | ||
cert = x509.load_pem_x509_certificate(cert_data, default_backend()) | ||
elif extension == '.cer': | ||
cert = x509.load_der_x509_certificate(cert_data, default_backend()) | ||
private_key = None # .cer files do not contain private key | ||
elif extension == '.pfx' or extension == '.p12': | ||
private_key, cert, _ = pkcs12.load_key_and_certificates( | ||
cert_data, password, default_backend()) | ||
else: | ||
raise ValueError(f"Unsupported certificate format {extension}") | ||
|
||
if private_key is not None: | ||
private_key_bytes = private_key.private_bytes( | ||
encoding=serialization.Encoding.PEM, | ||
format=serialization.PrivateFormat.PKCS8, | ||
encryption_algorithm=serialization.NoEncryption() | ||
) | ||
else: | ||
private_key_bytes = None | ||
|
||
thumbprint = cert.fingerprint(hashes.SHA1()).hex() | ||
|
||
return private_key_bytes, thumbprint | ||
|
||
class AzureAuth(AuthBase): | ||
|
||
def __init__(self, azure_auth_wrap: AzureAuthWrap): | ||
self.azure_auth_wrap = azure_auth_wrap | ||
self.token_cache = msal.SerializableTokenCache() | ||
try: | ||
# Try to load the token cache from a file | ||
with open(AZURE_SP_TOKEN_STORE_PATH, 'rb') as token_cache_file: | ||
self.token_cache.deserialize( | ||
json.dumps(pickle.load(token_cache_file))) | ||
except FileNotFoundError: | ||
# If the file does not exist, initialize a new token cache | ||
pass | ||
|
||
def __call__(self, r: PreparedRequest) -> PreparedRequest: | ||
if self.azure_auth_wrap.azure_auth_type == AzureAuthType.SERVICE_PRINCIPAL: | ||
self.acquire_token_silently_or_ondemand( | ||
r, self.azure_auth_wrap.azure_spsecret_auth) | ||
self.save_token_cache() | ||
elif self.azure_auth_wrap.azure_auth_type == AzureAuthType.CERTIFICATE: | ||
self.acquire_token_silently_or_ondemand( | ||
r, self.azure_auth_wrap.azure_spcert_auth) | ||
self.save_token_cache() | ||
# For device code and cli authentication, we use the access token directly | ||
# in future we can use msal to get the access token for device code | ||
elif self.azure_auth_wrap.azure_auth_type in [AzureAuthType.CLI, AzureAuthType.DEVICE_CODE]: | ||
access_token = None | ||
expires_on = None | ||
# Try to load the access token and its expiry time from a file | ||
scope = self.azure_auth_wrap.azure_cli_auth.scope if self.azure_auth_wrap.azure_cli_auth else self.azure_auth_wrap.azure_device_code.scope | ||
|
||
if os.path.exists(AZURE_CLI_TOKEN_STORE_PATH): | ||
request_logger.debug("azure cli token already exists, using") | ||
with open(AZURE_CLI_TOKEN_STORE_PATH, 'rb') as token_file: | ||
data = pickle.load(token_file) | ||
scope_wise_store = data.get(scope, {}) | ||
access_token = scope_wise_store.get('access_token', None) | ||
expires_on = scope_wise_store.get('expires_on', None) | ||
# Get the current time in seconds since the Epoch | ||
current_time = time.time() | ||
|
||
|
||
# If the file does not exist or the token has expired, get a new access token | ||
if not access_token or not expires_on or current_time >= expires_on: | ||
request_logger.debug( | ||
"azure cli token store cached not availabile or expired") | ||
# get token from cli by invoking az account get-access-token | ||
result = subprocess.run( | ||
["az", "account", "get-access-token", "--scope", scope], capture_output=True, text=True) | ||
result_json = json.loads(result.stdout) | ||
access_token = result_json['accessToken'] | ||
# Convert the expiresOn field to seconds since the Epoch | ||
expires_on = time.mktime(time.strptime(result_json['expiresOn'], '%Y-%m-%d %H:%M:%S.%f')) | ||
# Save the new access token and its expiry time to the file | ||
with open(AZURE_CLI_TOKEN_STORE_PATH, 'wb') as token_file: | ||
scope_wise_store = dict() | ||
scope_wise_store[scope] = { | ||
'access_token': access_token, 'expires_on': expires_on} | ||
pickle.dump(scope_wise_store, token_file) | ||
request_logger.debug( | ||
"computed or fetched azure cli token access bearer token and appeneded") | ||
r.headers["Authorization"] = f"Bearer {access_token}" | ||
return r | ||
|
||
def acquire_token_silently_or_ondemand(self, r, auth_wrap: AzureAuthSP): | ||
kwargs = { | ||
"client_id": auth_wrap.client_id, | ||
"authority": f"https://login.microsoftonline.com/{auth_wrap.tenant_id}", | ||
"token_cache": self.token_cache | ||
} | ||
if isinstance(auth_wrap, AzureAuthCertificate): | ||
try: | ||
private_key_bytes, thumbprint = load_private_key_and_thumbprint( | ||
auth_wrap.certificate_path, auth_wrap.certificate_password) | ||
kwargs["client_credential"] = { | ||
"private_key": private_key_bytes, | ||
"thumbprint": thumbprint | ||
} | ||
except Exception as e: | ||
request_logger.error( | ||
"loading private key failed with error", e) | ||
raise DothttpAzureAuthException(message=str(e)) | ||
else: | ||
kwargs["client_credential"] = auth_wrap.client_secret | ||
app = self.create_confidential_app(kwargs) | ||
accounts = app.get_accounts() | ||
if accounts: | ||
result = app.acquire_token_silent(scopes=[auth_wrap.scope], account=accounts[0]) | ||
if not accounts or "access_token" not in result: | ||
result = app.acquire_token_for_client(scopes=[auth_wrap.scope]) | ||
r.headers["Authorization"] = f"Bearer {result['access_token']}" | ||
|
||
def create_confidential_app(self, kwargs): | ||
return msal.ConfidentialClientApplication(**kwargs) | ||
|
||
def save_token_cache(self): | ||
with open(AZURE_SP_TOKEN_STORE_PATH, 'wb') as token_cache_file: | ||
pickle.dump(json.loads(self.token_cache.serialize()), | ||
token_cache_file) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.