Skip to content

Commit

Permalink
feat: okta support (#555)
Browse files Browse the repository at this point in the history
  • Loading branch information
karenc-bq authored Jun 6, 2024
1 parent 0e9b21e commit 2546889
Show file tree
Hide file tree
Showing 18 changed files with 904 additions and 290 deletions.
55 changes: 55 additions & 0 deletions aws_advanced_python_wrapper/credentials_provider_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, Dict, Optional, Protocol

import boto3

if TYPE_CHECKING:
from aws_advanced_python_wrapper.utils.properties import Properties

from abc import abstractmethod

from aws_advanced_python_wrapper.utils.properties import WrapperProperties


class CredentialsProviderFactory(Protocol):
@abstractmethod
def get_aws_credentials(self, region: str, props: Properties) -> Optional[Dict[str, str]]:
...


class SamlCredentialsProviderFactory(CredentialsProviderFactory):

def get_aws_credentials(self, region: str, props: Properties) -> Optional[Dict[str, str]]:
saml_assertion: str = self.get_saml_assertion(props)
session = boto3.Session()

sts_client = session.client(
'sts',
region_name=region
)

response: Dict[str, Dict[str, str]] = sts_client.assume_role_with_saml(
RoleArn=WrapperProperties.IAM_ROLE_ARN.get(props),
PrincipalArn=WrapperProperties.IAM_IDP_ARN.get(props),
SAMLAssertion=saml_assertion,
)

return response.get('Credentials')

def get_saml_assertion(self, props: Properties):
...
164 changes: 36 additions & 128 deletions aws_advanced_python_wrapper/federated_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@

from __future__ import annotations

from abc import abstractmethod
from html import unescape
from re import DOTALL, findall, search
from typing import TYPE_CHECKING, List, Protocol
from urllib.parse import urlencode, urlparse
from typing import TYPE_CHECKING, List
from urllib.parse import urlencode

from aws_advanced_python_wrapper.utils.iamutils import IamAuthUtils, TokenInfo
from aws_advanced_python_wrapper.credentials_provider_factory import (
CredentialsProviderFactory, SamlCredentialsProviderFactory)
from aws_advanced_python_wrapper.utils.iam_utils import IamAuthUtils, TokenInfo
from aws_advanced_python_wrapper.utils.saml_utils import SamlUtils

if TYPE_CHECKING:
from boto3 import Session
Expand All @@ -32,7 +34,6 @@
from datetime import datetime, timedelta
from typing import Callable, Dict, Optional, Set

import boto3
import requests

from aws_advanced_python_wrapper.errors import AwsWrapperError
Expand All @@ -58,6 +59,10 @@ def __init__(self, plugin_service: PluginService, credentials_provider_factory:
self._credentials_provider_factory = credentials_provider_factory
self._session = session

telemetry_factory = self._plugin_service.get_telemetry_factory()
self._fetch_token_counter = telemetry_factory.create_counter("federated.fetch_token.count")
self._cache_size_gauge = telemetry_factory.create_gauge("federated.token_cache.size", lambda: len(FederatedAuthPlugin._token_cache))

@property
def subscribed_methods(self) -> Set[str]:
return self._SUBSCRIBED_METHODS
Expand All @@ -73,14 +78,15 @@ def connect(
return self._connect(host_info, props, connect_func)

def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callable) -> Connection:
self._check_idp_credentials_with_fallback(props)
SamlUtils.check_idp_credentials_with_fallback(props)

host = IamAuthUtils.get_iam_host(props, host_info)
port = IamAuthUtils.get_port(props, host_info, self._plugin_service.database_dialect.default_port)
region: str = self._get_rds_region(host, props)
region: str = IamAuthUtils.get_rds_region(self._rds_utils, host, props, self._session)

cache_key: str = self._get_cache_key(
WrapperProperties.DB_USER.get(props),
user = WrapperProperties.DB_USER.get(props)
cache_key: str = IamAuthUtils.get_cache_key(
user,
host,
port,
region
Expand All @@ -89,17 +95,17 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
token_info = FederatedAuthPlugin._token_cache.get(cache_key)

if token_info is not None and not token_info.is_expired():
logger.debug("IamAuthPlugin.UseCachedIamToken", token_info.token)
logger.debug("FederatedAuthPlugin.UseCachedToken", token_info.token)
self._plugin_service.driver_dialect.set_password(props, token_info.token)
else:
self._update_authentication_token(host_info, props, region, cache_key)
self._update_authentication_token(host_info, props, user, region, cache_key)

WrapperProperties.USER.set(props, WrapperProperties.DB_USER.get(props))
WrapperProperties.USER.set(props, WrapperProperties.DB_USER.get(props))

try:
return connect_func()
except Exception:
self._update_authentication_token(host_info, props, region, cache_key)
self._update_authentication_token(host_info, props, user, region, cache_key)

try:
return connect_func()
Expand All @@ -121,77 +127,25 @@ def force_connect(
def _update_authentication_token(self,
host_info: HostInfo,
props: Properties,
user: Optional[str],
region: str,
cache_key: str) -> None:
token_expiration_sec: int = WrapperProperties.IAM_TOKEN_EXPIRATION.get_int(props)
token_expiry: datetime = datetime.now() + timedelta(seconds=token_expiration_sec)
port: int = IamAuthUtils.get_port(props, host_info, self._plugin_service.database_dialect.default_port)
credentials: Optional[Dict[str, str]] = self._credentials_provider_factory.get_aws_credentials(region, props)

token: str = self._generate_authentication_token(props, host_info.host, port, region, credentials)
logger.debug("IamAuthPlugin.GeneratedNewIamToken", token)
self._fetch_token_counter.inc()
token: str = IamAuthUtils.generate_authentication_token(
self._plugin_service,
user,
host_info.host,
port,
region,
credentials,
self._session)
WrapperProperties.PASSWORD.set(props, token)
FederatedAuthPlugin._token_cache[token] = TokenInfo(token, token_expiry)

def _get_rds_region(self, hostname: Optional[str], props: Properties) -> str:
rds_region = WrapperProperties.IAM_REGION.get(props)
if rds_region is None or rds_region == "":
rds_region = self._rds_utils.get_rds_region(hostname)

if not rds_region:
error_message = "RdsUtils.UnsupportedHostname"
logger.debug(error_message, hostname)
raise AwsWrapperError(Messages.get_formatted(error_message, hostname))

session = self._session if self._session else boto3.Session()
if rds_region not in session.get_available_regions("rds"):
error_message = "AwsSdk.UnsupportedRegion"
logger.debug(error_message, rds_region)
raise AwsWrapperError(Messages.get_formatted(error_message, rds_region))

return rds_region

def _generate_authentication_token(self,
props: Properties,
host_name: Optional[str],
port: Optional[int],
region: Optional[str],
credentials: Optional[Dict[str, str]]) -> str:
session = self._session if self._session else boto3.Session()

if credentials is not None:
client = session.client(
'rds',
region_name=region,
aws_access_key_id=credentials.get('AccessKeyId'),
aws_secret_access_key=credentials.get('SecretAccessKey'),
aws_session_token=credentials.get('SessionToken')
)
else:
client = session.client(
'rds',
region_name=region
)

user = WrapperProperties.USER.get(props)
token = client.generate_db_auth_token(
DBHostname=host_name,
Port=port,
DBUsername=user
)

client.close()

return token

def _get_cache_key(self, user: Optional[str], hostname: Optional[str], port: int, region: Optional[str]) -> str:
return f"{region}:{hostname}:{port}:{user}"

def _check_idp_credentials_with_fallback(self, props: Properties) -> None:
if WrapperProperties.IDP_USERNAME.get(props) is None:
WrapperProperties.IDP_USERNAME.set(props, WrapperProperties.USER.name)
if WrapperProperties.IDP_PASSWORD.get(props) is None:
WrapperProperties.IDP_PASSWORD.set(props, WrapperProperties.PASSWORD.name)
FederatedAuthPlugin._token_cache[cache_key] = TokenInfo(token, token_expiry)


class FederatedAuthPluginFactory(PluginFactory):
Expand All @@ -205,35 +159,6 @@ def get_credentials_provider_factory(self, plugin_service: PluginService, props:
raise AwsWrapperError(Messages.get_formatted("FederatedAuthPluginFactory.UnsupportedIdp", idp_name))


class CredentialsProviderFactory(Protocol):
@abstractmethod
def get_aws_credentials(self, region: str, props: Properties) -> Optional[Dict[str, str]]:
...


class SamlCredentialsProviderFactory(CredentialsProviderFactory):

def get_aws_credentials(self, region: str, props: Properties) -> Optional[Dict[str, str]]:
saml_assertion: str = self.get_saml_assertion(props)
session = boto3.Session()

sts_client = session.client(
'sts',
region_name=region
)

response: Dict[str, Dict[str, str]] = sts_client.assume_role_with_saml(
RoleArn=WrapperProperties.IAM_ROLE_ARN.get(props),
PrincipalArn=WrapperProperties.IAM_IDP_ARN.get(props),
SAMLAssertion=saml_assertion
)

return response.get('Credentials')

def get_saml_assertion(self, props: Properties):
...


class AdfsCredentialsProviderFactory(SamlCredentialsProviderFactory):
_INPUT_TAG_PATTERN = r"<input(.+?)/>"
_FORM_ACTION_PATTERN = r"<form.*?action=\"([^\"]+)\""
Expand Down Expand Up @@ -274,32 +199,22 @@ def get_saml_assertion(self, props: Properties):

def _get_sign_in_page_body(self, url: str, props: Properties) -> str:
logger.debug("AdfsCredentialsProviderFactory.SignOnPageUrl", url)
self._validate_url(url)
SamlUtils.validate_url(url)
r = requests.get(url,
verify=WrapperProperties.SSL_SECURE.get_bool(props),
timeout=WrapperProperties.HTTP_REQUEST_TIMEOUT.get_int(props))

# Check HTTP Status Code is 2xx Success
if r.status_code / 100 != 2:
error_message = "AdfsCredentialsProviderFactory.SignOnPageRequestFailed"
logger.debug(error_message, r.status_code, r.reason, r.text)
raise AwsWrapperError(Messages.get_formatted(error_message, r.status_code, r.reason, r.text))

SamlUtils.validate_response(r)
return r.text

def _post_form_action_body(self, uri: str, parameters: Dict[str, str], props: Properties) -> str:
logger.debug("AdfsCredentialsProviderFactory.SignOnPagePostActionUrl", uri)
self._validate_url(uri)
SamlUtils.validate_url(uri)
r = requests.post(uri, data=urlencode(parameters),
verify=WrapperProperties.SSL_SECURE.get_bool(props),
timeout=WrapperProperties.HTTP_REQUEST_TIMEOUT.get_int(props))
# Check HTTP Status Code is 2xx Success
if r.status_code / 100 != 2:
error_message = "AdfsCredentialsProviderFactory.SignOnPagePostActionRequestFailed"
logger.debug(error_message, r.status_code, r.reason, r.text)
raise AwsWrapperError(
Messages.get_formatted(error_message, r.status_code, r.reason, r.text))

SamlUtils.validate_response(r)
return r.text

def _get_sign_in_page_url(self, props) -> str:
Expand All @@ -308,7 +223,7 @@ def _get_sign_in_page_url(self, props) -> str:
relaying_party_id = WrapperProperties.RELAYING_PARTY_ID.get(props)
url = f"https://{idp_endpoint}:{idp_port}/adfs/ls/IdpInitiatedSignOn.aspx?loginToRp={relaying_party_id}"
if idp_endpoint is None or relaying_party_id is None:
error_message = "AdfsCredentialsProviderFactory.InvalidHttpsUrl"
error_message = "SamlUtils.InvalidHttpsUrl"
logger.debug(error_message, url)
raise AwsWrapperError(Messages.get_formatted(error_message, url))

Expand All @@ -319,7 +234,7 @@ def _get_form_action_url(self, props: Properties, action: str) -> str:
idp_port = WrapperProperties.IDP_PORT.get(props)
url = f"https://{idp_endpoint}:{idp_port}{action}"
if idp_endpoint is None:
error_message = "AdfsCredentialsProviderFactory.InvalidHttpsUrl"
error_message = "SamlUtils.InvalidHttpsUrl"
logger.debug(error_message, url)
raise AwsWrapperError(
Messages.get_formatted(error_message, url))
Expand Down Expand Up @@ -373,10 +288,3 @@ def _get_form_action_from_html_body(self, body: str) -> str:
return unescape(match.group(1))

return ""

def _validate_url(self, url: str) -> None:
result = urlparse(url)
if not result.scheme or not search(self._HTTPS_URL_PATTERN, url):
error_message = "AdfsCredentialsProviderFactory.InvalidHttpsUrl"
logger.debug(error_message, url)
raise AwsWrapperError(Messages.get_formatted(error_message, url))
Loading

0 comments on commit 2546889

Please sign in to comment.