Skip to content

Commit

Permalink
Azure Auth Support (#209)
Browse files Browse the repository at this point in the history
* chore upgrade `jsonschema` to 4.20.0
* chore upgrade `textx` to 4.0.1
* chore upgrade `restrictedpython` to 7.0
* chore upgrade `faker` to 20.1.0
* add new package `msal` to 1.26.0
* Add azure auth support 
* update badssl certificates
  • Loading branch information
cedric05 authored Dec 13, 2023
1 parent 109d1e6 commit 666a2b4
Show file tree
Hide file tree
Showing 17 changed files with 764 additions and 210 deletions.
11 changes: 6 additions & 5 deletions Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,24 @@ verify_ssl = true
name = "pypi"

[packages]
jsonschema = "==4.19.2"
jsonschema = "==4.20.0"
jstyleson = "==0.0.2"
requests = "==2.31.0"
textx = "==3.1.1"
textx = "==4.0.1"
js2py = "==0.74"
requests-pkcs12 = "==1.22"
parsys-requests-unixsocket = "==0.3.1"
requests-aws4auth = "==1.2.3"
requests-ntlm = "==1.2.0"
restrictedpython = "==6.2"
faker = "==20.0.0"
restrictedpython = "==7.0"
faker = "==20.1.0"
requests-hawk = "==1.2.1"
pyyaml = "==6.0.1"
toml = "==0.10.2"
python-magic = "*"
msal = "==1.26.0"

[dev-packages]
python-magic = "*"
waitress = "==2.1.1"
flask = "*"
pyperf = "==2.6.2"
Expand Down
342 changes: 180 additions & 162 deletions Pipfile.lock

Large diffs are not rendered by default.

47 changes: 44 additions & 3 deletions dothttp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,17 @@

from .dsl_jsonparser import json_or_array_to_json
from .exceptions import *
from .parse_models import MultidefHttp, AuthWrap, DigestAuth, BasicAuth, Line, NtlmAuthWrap, Query, Http, NameWrap, UrlWrap, Header, \
from .parse_models import AzureAuthCli, AzureAuthType, AzureAuthWrap, MultidefHttp, AuthWrap, DigestAuth, BasicAuth, Line, NtlmAuthWrap, Query, Http, NameWrap, UrlWrap, Header, \
MultiPartFile, FilesWrap, TripleOrDouble, Payload as ParsePayload, Certificate, P12Certificate, ExtraArg, \
AWS_REGION_LIST, AWS_SERVICES_LIST, AwsAuthWrap, TestScript, ScriptType, HawkAuth
AWS_REGION_LIST, AWS_SERVICES_LIST, AwsAuthWrap, TestScript, ScriptType, HawkAuth, AzureAuthCertificate, \
AzureAuthDeviceCode, AzureAuthServicePrincipal
from .property_schema import property_schema
from .property_util import PropertyProvider
try:
from .azure_auth import AzureAuth
except:
# this is for dothttp-wasm, where msal most likely not installed
AzureAuth = None

try:
import magic
Expand Down Expand Up @@ -320,6 +326,9 @@ def get_http_from_req(self):
aws_auth.signing_key.secret_key,
aws_auth.service,
aws_auth.region))
elif isinstance(self.auth, AzureAuth):
auth_wrap = AuthWrap(
azure_auth=self.auth.azure_auth_wrap)
certificate = None
if self.certificate:
certificate = Certificate(*self.certificate)
Expand Down Expand Up @@ -1051,12 +1060,44 @@ def load_auth(self):
aws_service
)
else:
# region and aws_service can be extracted from url
# aws service and region can be extracted from url
# somehow library is not supporting those
# with current state, we are not support this use case
# we may come back
# all four parameters are required and are to be non empty
raise DothttpAwsAuthException(access_id=access_id)
elif azure_auth := auth_wrap.azure_auth:
azure_auth: AzureAuthWrap = azure_auth
if sp_auth := azure_auth.azure_spsecret_auth:
azure_auth_wrap = AzureAuthWrap(azure_spsecret_auth=AzureAuthServicePrincipal(
tenant_id=self.get_updated_content(sp_auth.tenant_id),
client_id=self.get_updated_content(sp_auth.client_id),
client_secret=self.get_updated_content(sp_auth.client_secret),
scope = self.get_updated_content(sp_auth.scope or "https://management.azure.com/.default")
), azure_auth_type=AzureAuthType.SERVICE_PRINCIPAL)
elif cert_auth := azure_auth.azure_spcert_auth:
azure_auth_wrap = AzureAuthWrap(azure_spcert_auth=AzureAuthCertificate(
tenant_id=self.get_updated_content(cert_auth.tenant_id),
client_id=self.get_updated_content(cert_auth.client_id),
certificate_path=self.get_updated_content(cert_auth.certificate_path),
scope=self.get_updated_content(cert_auth.scope or "https://management.azure.com/.default")
), azure_auth_type=AzureAuthType.CERTIFICATE)
elif azure_auth.azure_cli_auth:
azure_auth_wrap = AzureAuthWrap(
azure_cli_auth=AzureAuthCli(
scope=self.get_updated_content(
azure_auth.azure_cli_auth.scope or "https://management.azure.com/.default"
)
), azure_auth_type=AzureAuthType.CLI)
elif azure_auth.auth.azure_device_code:
azure_auth_wrap = AzureAuthWrap(
azure_device_code=AzureAuthDeviceCode(
scope=self.get_updated_content(
azure_auth.azure_device_code.scope or "https://management.azure.com/.default"
)
), azure_auth_type=AzureAuthType.DEVICE_CODE)

self.httpdef.auth = AzureAuth(azure_auth_wrap)

def get_current_or_base(self, attr_key) -> Any:
if getattr(self.http, attr_key):
Expand Down
2 changes: 1 addition & 1 deletion dothttp/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.0.43a1'
__version__ = '0.0.43a2'
158 changes: 158 additions & 0 deletions dothttp/azure_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
import json
import msal
import os
import pickle
import subprocess
import time
import logging

from requests.auth import AuthBase
from requests.models import PreparedRequest

from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.serialization import pkcs12

from .exceptions import DothttpAzureAuthException
from .parse_models import AzureAuthWrap, AzureAuthType, AzureAuthSP, AzureAuthCertificate

AZURE_CLI_TOKEN_STORE_PATH = os.path.expanduser('~/.dothttp.azure-cli.pkl')

AZURE_SP_TOKEN_STORE_PATH = os.path.expanduser(
'~/.dothttp.msal_token_cache.pkl')

request_logger = logging.getLogger("request")


def load_private_key_and_thumbprint(cert_path, password=None):
extension = os.path.splitext(cert_path)[1].lower()
with open(cert_path, "rb") as cert_file:
cert_data = cert_file.read()

if extension == '.pem':
private_key = serialization.load_pem_private_key(
cert_data, password, default_backend())
cert = x509.load_pem_x509_certificate(cert_data, default_backend())
elif extension == '.cer':
cert = x509.load_der_x509_certificate(cert_data, default_backend())
private_key = None # .cer files do not contain private key
elif extension == '.pfx' or extension == '.p12':
private_key, cert, _ = pkcs12.load_key_and_certificates(
cert_data, password, default_backend())
else:
raise ValueError(f"Unsupported certificate format {extension}")

if private_key is not None:
private_key_bytes = private_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption()
)
else:
private_key_bytes = None

thumbprint = cert.fingerprint(hashes.SHA1()).hex()

return private_key_bytes, thumbprint

class AzureAuth(AuthBase):

def __init__(self, azure_auth_wrap: AzureAuthWrap):
self.azure_auth_wrap = azure_auth_wrap
self.token_cache = msal.SerializableTokenCache()
try:
# Try to load the token cache from a file
with open(AZURE_SP_TOKEN_STORE_PATH, 'rb') as token_cache_file:
self.token_cache.deserialize(
json.dumps(pickle.load(token_cache_file)))
except FileNotFoundError:
# If the file does not exist, initialize a new token cache
pass

def __call__(self, r: PreparedRequest) -> PreparedRequest:
if self.azure_auth_wrap.azure_auth_type == AzureAuthType.SERVICE_PRINCIPAL:
self.acquire_token_silently_or_ondemand(
r, self.azure_auth_wrap.azure_spsecret_auth)
self.save_token_cache()
elif self.azure_auth_wrap.azure_auth_type == AzureAuthType.CERTIFICATE:
self.acquire_token_silently_or_ondemand(
r, self.azure_auth_wrap.azure_spcert_auth)
self.save_token_cache()
# For device code and cli authentication, we use the access token directly
# in future we can use msal to get the access token for device code
elif self.azure_auth_wrap.azure_auth_type in [AzureAuthType.CLI, AzureAuthType.DEVICE_CODE]:
access_token = None
expires_on = None
# Try to load the access token and its expiry time from a file
scope = self.azure_auth_wrap.azure_cli_auth.scope if self.azure_auth_wrap.azure_cli_auth else self.azure_auth_wrap.azure_device_code.scope

if os.path.exists(AZURE_CLI_TOKEN_STORE_PATH):
request_logger.debug("azure cli token already exists, using")
with open(AZURE_CLI_TOKEN_STORE_PATH, 'rb') as token_file:
data = pickle.load(token_file)
scope_wise_store = data.get(scope, {})
access_token = scope_wise_store.get('access_token', None)
expires_on = scope_wise_store.get('expires_on', None)
# Get the current time in seconds since the Epoch
current_time = time.time()


# If the file does not exist or the token has expired, get a new access token
if not access_token or not expires_on or current_time >= expires_on:
request_logger.debug(
"azure cli token store cached not availabile or expired")
# get token from cli by invoking az account get-access-token
result = subprocess.run(
["az", "account", "get-access-token", "--scope", scope], capture_output=True, text=True)
result_json = json.loads(result.stdout)
access_token = result_json['accessToken']
# Convert the expiresOn field to seconds since the Epoch
expires_on = time.mktime(time.strptime(result_json['expiresOn'], '%Y-%m-%d %H:%M:%S.%f'))
# Save the new access token and its expiry time to the file
with open(AZURE_CLI_TOKEN_STORE_PATH, 'wb') as token_file:
scope_wise_store = dict()
scope_wise_store[scope] = {
'access_token': access_token, 'expires_on': expires_on}
pickle.dump(scope_wise_store, token_file)
request_logger.debug(
"computed or fetched azure cli token access bearer token and appeneded")
r.headers["Authorization"] = f"Bearer {access_token}"
return r

def acquire_token_silently_or_ondemand(self, r, auth_wrap: AzureAuthSP):
kwargs = {
"client_id": auth_wrap.client_id,
"authority": f"https://login.microsoftonline.com/{auth_wrap.tenant_id}",
"token_cache": self.token_cache
}
if isinstance(auth_wrap, AzureAuthCertificate):
try:
private_key_bytes, thumbprint = load_private_key_and_thumbprint(
auth_wrap.certificate_path, auth_wrap.certificate_password)
kwargs["client_credential"] = {
"private_key": private_key_bytes,
"thumbprint": thumbprint
}
except Exception as e:
request_logger.error(
"loading private key failed with error", e)
raise DothttpAzureAuthException(message=str(e))
else:
kwargs["client_credential"] = auth_wrap.client_secret
app = self.create_confidential_app(kwargs)
accounts = app.get_accounts()
if accounts:
result = app.acquire_token_silent(scopes=[auth_wrap.scope], account=accounts[0])
if not accounts or "access_token" not in result:
result = app.acquire_token_for_client(scopes=[auth_wrap.scope])
r.headers["Authorization"] = f"Bearer {result['access_token']}"

def create_confidential_app(self, kwargs):
return msal.ConfidentialClientApplication(**kwargs)

def save_token_cache(self):
with open(AZURE_SP_TOKEN_STORE_PATH, 'wb') as token_cache_file:
pickle.dump(json.loads(self.token_cache.serialize()),
token_cache_file)
6 changes: 6 additions & 0 deletions dothttp/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,9 @@ class ScriptException(DotHttpException):
"AWSAuth expects all(access_id, secret_token, region, service) to be non empty access_id:`{access_id}`")
class DothttpAwsAuthException(DotHttpException):
pass


@exception_wrapper(
"AzureAuth exception: {message}")
class DothttpAzureAuthException(DotHttpException):
pass
38 changes: 37 additions & 1 deletion dothttp/http.tx
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,12 @@ HEADER:
;

AUTHWRAP:
digest_auth = DIGESTAUTH | basic_auth = BASICAUTH | ntlm_auth = NTLMAUTH | hawk_auth=HAWKAUTH | aws_auth = AWSAUTH
digest_auth = DIGESTAUTH
| basic_auth = BASICAUTH
| ntlm_auth = NTLMAUTH
| hawk_auth=HAWKAUTH
| aws_auth = AWSAUTH
| azure_auth = AZUREAUTH
;

CERTAUTH:
Expand Down Expand Up @@ -80,6 +85,37 @@ AWSAUTH:
| 'awsauth' '(' ('access_id' '=' ) ? access_id=DotString ',' ('secret_key' '=')? secret_token=DotString ',' ('service' '=')? service=DotString (',' ('region' '=')? region=DotString)? ')'
;

AZUREAUTH:
azure_spcert_auth = AZURECERTIFICATEAUTH
// for service principal
| azure_spsecret_auth = AZURESERVICEPRINCIPALAUTH
// if you don't have both, use device_code flow
// in this case, we use azure cli to get auth_token
| azure_device_auth = AZUREDEVICEAUTH
| azure_cli_auth = AZURECLIAUTH
;

AZURESERVICEPRINCIPALAUTH:

'azurespsecret' '(' ('tenant_id' '=')? tenant_id=DotString ',' ('client_id' '=')? client_id=DotString ',' ('client_secret' '=')? client_secret=DotString (',' ('scope' '=')? scope=DotString )? ')'

;

AZURECERTIFICATEAUTH:
// for certificate based
'azurespcert' '(' ('tenant_id' '=')? tenant_id=DotString ',' ('client_id' '=')? client_id=DotString ',' ('certificate_path' '=')? certificate_path=DotString (',' ('scope' '=')? scope=DotString )?')'

;

AZUREDEVICEAUTH:
'azuredevice' '(' (('scope' '=')? scope=DotString )? ')'
;

AZURECLIAUTH:
'azurecli' '(' (('scope' '=')? scope=DotString )? ')'
;


EXTRA_ARG:
// there can be more
clear=CLEAR_SESSION | insecure=INSECURE
Expand Down
Loading

0 comments on commit 666a2b4

Please sign in to comment.