diff --git a/aws_advanced_python_wrapper/credentials_provider_factory.py b/aws_advanced_python_wrapper/credentials_provider_factory.py new file mode 100644 index 00000000..71a76e6f --- /dev/null +++ b/aws_advanced_python_wrapper/credentials_provider_factory.py @@ -0,0 +1,55 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict, Optional, Protocol + +import boto3 + +if TYPE_CHECKING: + from aws_advanced_python_wrapper.utils.properties import Properties + +from abc import abstractmethod + +from aws_advanced_python_wrapper.utils.properties import WrapperProperties + + +class CredentialsProviderFactory(Protocol): + @abstractmethod + def get_aws_credentials(self, region: str, props: Properties) -> Optional[Dict[str, str]]: + ... + + +class SamlCredentialsProviderFactory(CredentialsProviderFactory): + + def get_aws_credentials(self, region: str, props: Properties) -> Optional[Dict[str, str]]: + saml_assertion: str = self.get_saml_assertion(props) + session = boto3.Session() + + sts_client = session.client( + 'sts', + region_name=region + ) + + response: Dict[str, Dict[str, str]] = sts_client.assume_role_with_saml( + RoleArn=WrapperProperties.IAM_ROLE_ARN.get(props), + PrincipalArn=WrapperProperties.IAM_IDP_ARN.get(props), + SAMLAssertion=saml_assertion, + ) + + return response.get('Credentials') + + def get_saml_assertion(self, props: Properties): + ... diff --git a/aws_advanced_python_wrapper/federated_plugin.py b/aws_advanced_python_wrapper/federated_plugin.py index 44bdcfb3..8a25c461 100644 --- a/aws_advanced_python_wrapper/federated_plugin.py +++ b/aws_advanced_python_wrapper/federated_plugin.py @@ -14,13 +14,15 @@ from __future__ import annotations -from abc import abstractmethod from html import unescape from re import DOTALL, findall, search -from typing import TYPE_CHECKING, List, Protocol -from urllib.parse import urlencode, urlparse +from typing import TYPE_CHECKING, List +from urllib.parse import urlencode -from aws_advanced_python_wrapper.utils.iamutils import IamAuthUtils, TokenInfo +from aws_advanced_python_wrapper.credentials_provider_factory import ( + CredentialsProviderFactory, SamlCredentialsProviderFactory) +from aws_advanced_python_wrapper.utils.iam_utils import IamAuthUtils, TokenInfo +from aws_advanced_python_wrapper.utils.saml_utils import SamlUtils if TYPE_CHECKING: from boto3 import Session @@ -32,7 +34,6 @@ from datetime import datetime, timedelta from typing import Callable, Dict, Optional, Set -import boto3 import requests from aws_advanced_python_wrapper.errors import AwsWrapperError @@ -58,6 +59,10 @@ def __init__(self, plugin_service: PluginService, credentials_provider_factory: self._credentials_provider_factory = credentials_provider_factory self._session = session + telemetry_factory = self._plugin_service.get_telemetry_factory() + self._fetch_token_counter = telemetry_factory.create_counter("federated.fetch_token.count") + self._cache_size_gauge = telemetry_factory.create_gauge("federated.token_cache.size", lambda: len(FederatedAuthPlugin._token_cache)) + @property def subscribed_methods(self) -> Set[str]: return self._SUBSCRIBED_METHODS @@ -73,14 +78,15 @@ def connect( return self._connect(host_info, props, connect_func) def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callable) -> Connection: - self._check_idp_credentials_with_fallback(props) + SamlUtils.check_idp_credentials_with_fallback(props) host = IamAuthUtils.get_iam_host(props, host_info) port = IamAuthUtils.get_port(props, host_info, self._plugin_service.database_dialect.default_port) - region: str = self._get_rds_region(host, props) + region: str = IamAuthUtils.get_rds_region(self._rds_utils, host, props, self._session) - cache_key: str = self._get_cache_key( - WrapperProperties.DB_USER.get(props), + user = WrapperProperties.DB_USER.get(props) + cache_key: str = IamAuthUtils.get_cache_key( + user, host, port, region @@ -89,17 +95,17 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl token_info = FederatedAuthPlugin._token_cache.get(cache_key) if token_info is not None and not token_info.is_expired(): - logger.debug("IamAuthPlugin.UseCachedIamToken", token_info.token) + logger.debug("FederatedAuthPlugin.UseCachedToken", token_info.token) self._plugin_service.driver_dialect.set_password(props, token_info.token) else: - self._update_authentication_token(host_info, props, region, cache_key) + self._update_authentication_token(host_info, props, user, region, cache_key) - WrapperProperties.USER.set(props, WrapperProperties.DB_USER.get(props)) + WrapperProperties.USER.set(props, WrapperProperties.DB_USER.get(props)) try: return connect_func() except Exception: - self._update_authentication_token(host_info, props, region, cache_key) + self._update_authentication_token(host_info, props, user, region, cache_key) try: return connect_func() @@ -121,6 +127,7 @@ def force_connect( def _update_authentication_token(self, host_info: HostInfo, props: Properties, + user: Optional[str], region: str, cache_key: str) -> None: token_expiration_sec: int = WrapperProperties.IAM_TOKEN_EXPIRATION.get_int(props) @@ -128,70 +135,17 @@ def _update_authentication_token(self, port: int = IamAuthUtils.get_port(props, host_info, self._plugin_service.database_dialect.default_port) credentials: Optional[Dict[str, str]] = self._credentials_provider_factory.get_aws_credentials(region, props) - token: str = self._generate_authentication_token(props, host_info.host, port, region, credentials) - logger.debug("IamAuthPlugin.GeneratedNewIamToken", token) + self._fetch_token_counter.inc() + token: str = IamAuthUtils.generate_authentication_token( + self._plugin_service, + user, + host_info.host, + port, + region, + credentials, + self._session) WrapperProperties.PASSWORD.set(props, token) - FederatedAuthPlugin._token_cache[token] = TokenInfo(token, token_expiry) - - def _get_rds_region(self, hostname: Optional[str], props: Properties) -> str: - rds_region = WrapperProperties.IAM_REGION.get(props) - if rds_region is None or rds_region == "": - rds_region = self._rds_utils.get_rds_region(hostname) - - if not rds_region: - error_message = "RdsUtils.UnsupportedHostname" - logger.debug(error_message, hostname) - raise AwsWrapperError(Messages.get_formatted(error_message, hostname)) - - session = self._session if self._session else boto3.Session() - if rds_region not in session.get_available_regions("rds"): - error_message = "AwsSdk.UnsupportedRegion" - logger.debug(error_message, rds_region) - raise AwsWrapperError(Messages.get_formatted(error_message, rds_region)) - - return rds_region - - def _generate_authentication_token(self, - props: Properties, - host_name: Optional[str], - port: Optional[int], - region: Optional[str], - credentials: Optional[Dict[str, str]]) -> str: - session = self._session if self._session else boto3.Session() - - if credentials is not None: - client = session.client( - 'rds', - region_name=region, - aws_access_key_id=credentials.get('AccessKeyId'), - aws_secret_access_key=credentials.get('SecretAccessKey'), - aws_session_token=credentials.get('SessionToken') - ) - else: - client = session.client( - 'rds', - region_name=region - ) - - user = WrapperProperties.USER.get(props) - token = client.generate_db_auth_token( - DBHostname=host_name, - Port=port, - DBUsername=user - ) - - client.close() - - return token - - def _get_cache_key(self, user: Optional[str], hostname: Optional[str], port: int, region: Optional[str]) -> str: - return f"{region}:{hostname}:{port}:{user}" - - def _check_idp_credentials_with_fallback(self, props: Properties) -> None: - if WrapperProperties.IDP_USERNAME.get(props) is None: - WrapperProperties.IDP_USERNAME.set(props, WrapperProperties.USER.name) - if WrapperProperties.IDP_PASSWORD.get(props) is None: - WrapperProperties.IDP_PASSWORD.set(props, WrapperProperties.PASSWORD.name) + FederatedAuthPlugin._token_cache[cache_key] = TokenInfo(token, token_expiry) class FederatedAuthPluginFactory(PluginFactory): @@ -205,35 +159,6 @@ def get_credentials_provider_factory(self, plugin_service: PluginService, props: raise AwsWrapperError(Messages.get_formatted("FederatedAuthPluginFactory.UnsupportedIdp", idp_name)) -class CredentialsProviderFactory(Protocol): - @abstractmethod - def get_aws_credentials(self, region: str, props: Properties) -> Optional[Dict[str, str]]: - ... - - -class SamlCredentialsProviderFactory(CredentialsProviderFactory): - - def get_aws_credentials(self, region: str, props: Properties) -> Optional[Dict[str, str]]: - saml_assertion: str = self.get_saml_assertion(props) - session = boto3.Session() - - sts_client = session.client( - 'sts', - region_name=region - ) - - response: Dict[str, Dict[str, str]] = sts_client.assume_role_with_saml( - RoleArn=WrapperProperties.IAM_ROLE_ARN.get(props), - PrincipalArn=WrapperProperties.IAM_IDP_ARN.get(props), - SAMLAssertion=saml_assertion - ) - - return response.get('Credentials') - - def get_saml_assertion(self, props: Properties): - ... - - class AdfsCredentialsProviderFactory(SamlCredentialsProviderFactory): _INPUT_TAG_PATTERN = r"" _FORM_ACTION_PATTERN = r" str: logger.debug("AdfsCredentialsProviderFactory.SignOnPageUrl", url) - self._validate_url(url) + SamlUtils.validate_url(url) r = requests.get(url, verify=WrapperProperties.SSL_SECURE.get_bool(props), timeout=WrapperProperties.HTTP_REQUEST_TIMEOUT.get_int(props)) - # Check HTTP Status Code is 2xx Success - if r.status_code / 100 != 2: - error_message = "AdfsCredentialsProviderFactory.SignOnPageRequestFailed" - logger.debug(error_message, r.status_code, r.reason, r.text) - raise AwsWrapperError(Messages.get_formatted(error_message, r.status_code, r.reason, r.text)) - + SamlUtils.validate_response(r) return r.text def _post_form_action_body(self, uri: str, parameters: Dict[str, str], props: Properties) -> str: logger.debug("AdfsCredentialsProviderFactory.SignOnPagePostActionUrl", uri) - self._validate_url(uri) + SamlUtils.validate_url(uri) r = requests.post(uri, data=urlencode(parameters), verify=WrapperProperties.SSL_SECURE.get_bool(props), timeout=WrapperProperties.HTTP_REQUEST_TIMEOUT.get_int(props)) # Check HTTP Status Code is 2xx Success - if r.status_code / 100 != 2: - error_message = "AdfsCredentialsProviderFactory.SignOnPagePostActionRequestFailed" - logger.debug(error_message, r.status_code, r.reason, r.text) - raise AwsWrapperError( - Messages.get_formatted(error_message, r.status_code, r.reason, r.text)) - + SamlUtils.validate_response(r) return r.text def _get_sign_in_page_url(self, props) -> str: @@ -308,7 +223,7 @@ def _get_sign_in_page_url(self, props) -> str: relaying_party_id = WrapperProperties.RELAYING_PARTY_ID.get(props) url = f"https://{idp_endpoint}:{idp_port}/adfs/ls/IdpInitiatedSignOn.aspx?loginToRp={relaying_party_id}" if idp_endpoint is None or relaying_party_id is None: - error_message = "AdfsCredentialsProviderFactory.InvalidHttpsUrl" + error_message = "SamlUtils.InvalidHttpsUrl" logger.debug(error_message, url) raise AwsWrapperError(Messages.get_formatted(error_message, url)) @@ -319,7 +234,7 @@ def _get_form_action_url(self, props: Properties, action: str) -> str: idp_port = WrapperProperties.IDP_PORT.get(props) url = f"https://{idp_endpoint}:{idp_port}{action}" if idp_endpoint is None: - error_message = "AdfsCredentialsProviderFactory.InvalidHttpsUrl" + error_message = "SamlUtils.InvalidHttpsUrl" logger.debug(error_message, url) raise AwsWrapperError( Messages.get_formatted(error_message, url)) @@ -373,10 +288,3 @@ def _get_form_action_from_html_body(self, body: str) -> str: return unescape(match.group(1)) return "" - - def _validate_url(self, url: str) -> None: - result = urlparse(url) - if not result.scheme or not search(self._HTTPS_URL_PATTERN, url): - error_message = "AdfsCredentialsProviderFactory.InvalidHttpsUrl" - logger.debug(error_message, url) - raise AwsWrapperError(Messages.get_formatted(error_message, url)) diff --git a/aws_advanced_python_wrapper/iam_plugin.py b/aws_advanced_python_wrapper/iam_plugin.py index 6d128976..c4210d8f 100644 --- a/aws_advanced_python_wrapper/iam_plugin.py +++ b/aws_advanced_python_wrapper/iam_plugin.py @@ -16,7 +16,7 @@ from typing import TYPE_CHECKING -from aws_advanced_python_wrapper.utils.iamutils import IamAuthUtils, TokenInfo +from aws_advanced_python_wrapper.utils.iam_utils import IamAuthUtils, TokenInfo if TYPE_CHECKING: from boto3 import Session @@ -28,8 +28,6 @@ from datetime import datetime, timedelta from typing import Callable, Dict, Optional, Set -import boto3 - from aws_advanced_python_wrapper.errors import AwsWrapperError from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory from aws_advanced_python_wrapper.utils.log import Logger @@ -37,8 +35,6 @@ from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils -from aws_advanced_python_wrapper.utils.telemetry.telemetry import \ - TelemetryTraceLevel logger = Logger(__name__) @@ -75,17 +71,18 @@ def connect( return self._connect(host_info, props, connect_func) def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callable) -> Connection: - if not WrapperProperties.USER.get(props): - raise AwsWrapperError(Messages.get_formatted("IamPlugin.IsNoneOrEmpty", WrapperProperties.USER.name)) + user = WrapperProperties.USER.get(props) + if not user: + raise AwsWrapperError(Messages.get_formatted("IamAuthPlugin.IsNoneOrEmpty", WrapperProperties.USER.name)) host = IamAuthUtils.get_iam_host(props, host_info) region = WrapperProperties.IAM_REGION.get(props) \ - if WrapperProperties.IAM_REGION.get(props) else self._get_rds_region(host) + if WrapperProperties.IAM_REGION.get(props) else IamAuthUtils.get_rds_region(self._rds_utils, host, props, self._session) port = IamAuthUtils.get_port(props, host_info, self._plugin_service.database_dialect.default_port) token_expiration_sec: int = WrapperProperties.IAM_EXPIRATION.get_int(props) - cache_key: str = self._get_cache_key( - WrapperProperties.USER.get(props), + cache_key: str = IamAuthUtils.get_cache_key( + user, host, port, region @@ -98,8 +95,8 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl self._plugin_service.driver_dialect.set_password(props, token_info.token) else: token_expiry = datetime.now() + timedelta(seconds=token_expiration_sec) - token: str = self._generate_authentication_token(props, host, port, region) - logger.debug("IamAuthPlugin.GeneratedNewIamToken", token) + self._fetch_token_counter.inc() + token: str = IamAuthUtils.generate_authentication_token(self._plugin_service, user, host, port, region, client_session=self._session) self._plugin_service.driver_dialect.set_password(props, token) IamAuthPlugin._token_cache[cache_key] = TokenInfo(token, token_expiry) @@ -116,10 +113,10 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl # Login unsuccessful with cached token # Try to generate a new token and try to connect again token_expiry = datetime.now() + timedelta(seconds=token_expiration_sec) - token = self._generate_authentication_token(props, host, port, region) - logger.debug("IamAuthPlugin.GeneratedNewIamToken", token) + self._fetch_token_counter.inc() + token = IamAuthUtils.generate_authentication_token(self._plugin_service, user, host, port, region, client_session=self._session) self._plugin_service.driver_dialect.set_password(props, token) - IamAuthPlugin._token_cache[token] = TokenInfo(token, token_expiry) + IamAuthPlugin._token_cache[cache_key] = TokenInfo(token, token_expiry) try: return connect_func() @@ -136,59 +133,6 @@ def force_connect( force_connect_func: Callable) -> Connection: return self._connect(host_info, props, force_connect_func) - def _generate_authentication_token(self, - props: Properties, - hostname: Optional[str], - port: Optional[int], - region: Optional[str]) -> str: - telemetry_factory = self._plugin_service.get_telemetry_factory() - context = telemetry_factory.open_telemetry_context("fetch IAM token", TelemetryTraceLevel.NESTED) - self._fetch_token_counter.inc() - - try: - session = self._session if self._session else boto3.Session() - client = session.client( - 'rds', - region_name=region, - ) - - user = WrapperProperties.USER.get(props) - - token = client.generate_db_auth_token( - DBHostname=hostname, - Port=port, - DBUsername=user - ) - - client.close() - - return token - except Exception as ex: - context.set_success(False) - context.set_exception(ex) - raise ex - finally: - context.close_context() - - def _get_cache_key(self, user: Optional[str], hostname: Optional[str], port: int, region: Optional[str]) -> str: - return f"{region}:{hostname}:{port}:{user}" - - def _get_rds_region(self, hostname: Optional[str]) -> str: - rds_region = self._rds_utils.get_rds_region(hostname) if hostname else None - - if not rds_region: - exception_message = "RdsUtils.UnsupportedHostname" - logger.debug(exception_message, hostname) - raise AwsWrapperError(Messages.get_formatted(exception_message, hostname)) - - session = self._session if self._session else boto3.Session() - if rds_region not in session.get_available_regions("rds"): - exception_message = "AwsSdk.UnsupportedRegion" - logger.debug(exception_message, rds_region) - raise AwsWrapperError(Messages.get_formatted(exception_message, rds_region)) - - return rds_region - class IamAuthPluginFactory(PluginFactory): def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: diff --git a/aws_advanced_python_wrapper/okta_plugin.py b/aws_advanced_python_wrapper/okta_plugin.py new file mode 100644 index 00000000..7e6b26fd --- /dev/null +++ b/aws_advanced_python_wrapper/okta_plugin.py @@ -0,0 +1,225 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from datetime import datetime, timedelta +from html import unescape +from re import search +from typing import TYPE_CHECKING, Callable, Dict, Optional, Set + +from aws_advanced_python_wrapper.credentials_provider_factory import ( + CredentialsProviderFactory, SamlCredentialsProviderFactory) +from aws_advanced_python_wrapper.utils.iam_utils import IamAuthUtils, TokenInfo +from aws_advanced_python_wrapper.utils.saml_utils import SamlUtils + +if TYPE_CHECKING: + from boto3 import Session + from aws_advanced_python_wrapper.driver_dialect import DriverDialect + from aws_advanced_python_wrapper.hostinfo import HostInfo + from aws_advanced_python_wrapper.pep249 import Connection + from aws_advanced_python_wrapper.plugin_service import PluginService + +import requests + +from aws_advanced_python_wrapper.errors import AwsWrapperError +from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory +from aws_advanced_python_wrapper.utils.log import Logger +from aws_advanced_python_wrapper.utils.messages import Messages +from aws_advanced_python_wrapper.utils.properties import (Properties, + WrapperProperties) +from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils + +logger = Logger(__name__) + + +class OktaAuthPlugin(Plugin): + _SUBSCRIBED_METHODS: Set[str] = {"connect", "force_connect"} + + _rds_utils: RdsUtils = RdsUtils() + _token_cache: Dict[str, TokenInfo] = {} + + def __init__(self, plugin_service: PluginService, credentials_provider_factory: CredentialsProviderFactory, session: Optional[Session] = None): + self._plugin_service = plugin_service + self._credentials_provider_factory = credentials_provider_factory + self._session = session + + telemetry_factory = self._plugin_service.get_telemetry_factory() + self._fetch_token_counter = telemetry_factory.create_counter("okta.fetch_token.count") + self._cache_size_gauge = telemetry_factory.create_gauge("okta.token_cache.size", lambda: len(OktaAuthPlugin._token_cache)) + + @property + def subscribed_methods(self) -> Set[str]: + return self._SUBSCRIBED_METHODS + + def connect( + self, + target_driver_func: Callable, + driver_dialect: DriverDialect, + host_info: HostInfo, + props: Properties, + is_initial_connection: bool, + connect_func: Callable) -> Connection: + return self._connect(host_info, props, connect_func) + + def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callable) -> Connection: + SamlUtils.check_idp_credentials_with_fallback(props) + + host = IamAuthUtils.get_iam_host(props, host_info) + port = IamAuthUtils.get_port(props, host_info, self._plugin_service.database_dialect.default_port) + region: str = IamAuthUtils.get_rds_region(self._rds_utils, host, props, self._session) + + user = WrapperProperties.DB_USER.get(props) + cache_key: str = IamAuthUtils.get_cache_key( + user, + host, + port, + region + ) + + token_info = OktaAuthPlugin._token_cache.get(cache_key) + + if token_info is not None and not token_info.is_expired(): + logger.debug("OktaAuthPlugin.UseCachedToken", token_info.token) + self._plugin_service.driver_dialect.set_password(props, token_info.token) + else: + self._update_authentication_token(host_info, props, user, region, cache_key) + + WrapperProperties.USER.set(props, WrapperProperties.DB_USER.get(props)) + + try: + return connect_func() + except Exception: + self._update_authentication_token(host_info, props, user, region, cache_key) + + try: + return connect_func() + except Exception as e: + error_message = "OktaAuthPlugin.UnhandledException" + logger.debug(error_message, e) + raise AwsWrapperError(Messages.get_formatted(error_message, e)) from e + + def force_connect( + self, + target_driver_func: Callable, + driver_dialect: DriverDialect, + host_info: HostInfo, + props: Properties, + is_initial_connection: bool, + force_connect_func: Callable) -> Connection: + return self._connect(host_info, props, force_connect_func) + + def _update_authentication_token(self, + host_info: HostInfo, + props: Properties, + user: Optional[str], + region: str, + cache_key: str) -> None: + token_expiration_sec: int = WrapperProperties.IAM_TOKEN_EXPIRATION.get_int(props) + token_expiry: datetime = datetime.now() + timedelta(seconds=token_expiration_sec) + port: int = IamAuthUtils.get_port(props, host_info, self._plugin_service.database_dialect.default_port) + credentials: Optional[Dict[str, str]] = self._credentials_provider_factory.get_aws_credentials(region, props) + + token: str = IamAuthUtils.generate_authentication_token( + self._plugin_service, + user, + host_info.host, + port, + region, + credentials, + self._session + ) + WrapperProperties.PASSWORD.set(props, token) + OktaAuthPlugin._token_cache[cache_key] = TokenInfo(token, token_expiry) + + +class OktaCredentialsProviderFactory(SamlCredentialsProviderFactory): + _SAML_RESPONSE_PATTERN = r"\"SAMLResponse\" .* value=\"(?P[^\"]+)\"" + _SAML_RESPONSE_PATTERN_GROUP = "saml" + _HTTPS_URL_PATTERN = r"^(https)://[-a-zA-Z0-9+&@#/%?=~_!:,.']*[-a-zA-Z0-9+&@#/%=~_']" + _OKTA_AWS_APP_NAME = "amazon_aws" + _ONE_TIME_TOKEN = "onetimetoken" + _SESSION_TOKEN = "sessionToken" + + def __init__(self, plugin_service: PluginService, props: Properties): + self._plugin_service = plugin_service + self._properties = props + + def _get_session_token(self, props: Properties) -> str: + idp_endpoint = WrapperProperties.IDP_ENDPOINT.get(props) + idp_user = WrapperProperties.IDP_USERNAME.get(props) + idp_password = WrapperProperties.IDP_PASSWORD.get(props) + + session_token_endpoint = f"https://{idp_endpoint}/api/v1/authn" + + request_body = { + "username": idp_user, + "password": idp_password, + } + try: + r = requests.post(session_token_endpoint, + headers={'Content-Type': 'application/json', 'Accept': 'application/json'}, + json=request_body, + verify=WrapperProperties.SSL_SECURE.get_bool(props), + timeout=WrapperProperties.HTTP_REQUEST_TIMEOUT.get_int(props)) + if r.status_code / 100 != 2: + error_message = "SamlUtils.RequestFailed" + logger.debug(error_message, r.status_code, r.reason, r.text) + raise AwsWrapperError(Messages.get_formatted(error_message, r.status_code, r.reason, r.text)) + return r.json().get(OktaCredentialsProviderFactory._SESSION_TOKEN) + except IOError as e: + error_message = "OktaAuthPlugin.UnhandledException" + logger.debug(error_message, e) + raise AwsWrapperError(Messages.get_formatted(error_message, e)) + + def _get_saml_url(self, props: Properties) -> str: + idp_endpoint = WrapperProperties.IDP_ENDPOINT.get(props) + app_id = WrapperProperties.APP_ID.get(props) + return f"https://{idp_endpoint}/app/{OktaCredentialsProviderFactory._OKTA_AWS_APP_NAME}/{app_id}/sso/saml" + + def get_saml_assertion(self, props: Properties): + try: + one_time_token = self._get_session_token(props) + uri = self._get_saml_url(props) + SamlUtils.validate_url(uri) + + logger.debug("OktaCredentialsProviderFactory.SamlAssertionUrl", uri) + r = requests.get(uri, + params={OktaCredentialsProviderFactory._ONE_TIME_TOKEN: one_time_token}, + verify=WrapperProperties.SSL_SECURE.get_bool(props), + timeout=WrapperProperties.HTTP_REQUEST_TIMEOUT.get_int(props)) + + SamlUtils.validate_response(r) + content = r.text + match = search(OktaCredentialsProviderFactory._SAML_RESPONSE_PATTERN, content) + if not match: + error_message = "AdfsCredentialsProviderFactory.FailedLogin" + logger.debug(error_message, content) + raise AwsWrapperError(Messages.get_formatted(error_message, content)) + + # return SAML Response value + return unescape(match.group(self._SAML_RESPONSE_PATTERN_GROUP)) + + except IOError as e: + error_message = "OktaAuthPlugin.UnhandledException" + logger.debug(error_message, e) + raise AwsWrapperError(Messages.get_formatted(error_message, e)) + + +class OktaAuthPluginFactory(PluginFactory): + def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: + return OktaAuthPlugin(plugin_service, self.get_credentials_provider_factory(plugin_service, props)) + + def get_credentials_provider_factory(self, plugin_service: PluginService, props: Properties) -> OktaCredentialsProviderFactory: + return OktaCredentialsProviderFactory(plugin_service, props) diff --git a/aws_advanced_python_wrapper/plugin_service.py b/aws_advanced_python_wrapper/plugin_service.py index 7ca755e0..69a458c7 100644 --- a/aws_advanced_python_wrapper/plugin_service.py +++ b/aws_advanced_python_wrapper/plugin_service.py @@ -22,6 +22,7 @@ FastestResponseStrategyPluginFactory from aws_advanced_python_wrapper.federated_plugin import \ FederatedAuthPluginFactory +from aws_advanced_python_wrapper.okta_plugin import OktaAuthPluginFactory from aws_advanced_python_wrapper.states.session_state_service import ( SessionStateService, SessionStateServiceImpl) @@ -625,6 +626,7 @@ class PluginManager(CanReleaseResources): "execute_time": ExecuteTimePluginFactory, "dev": DeveloperPluginFactory, "federated_auth": FederatedAuthPluginFactory, + "okta": OktaAuthPluginFactory, "initial_connection": AuroraInitialConnectionStrategyPluginFactory } @@ -643,10 +645,11 @@ class PluginManager(CanReleaseResources): FastestResponseStrategyPluginFactory: 600, IamAuthPluginFactory: 700, AwsSecretsManagerPluginFactory: 800, + FederatedAuthPluginFactory: 900, + OktaAuthPluginFactory: 1000, ConnectTimePluginFactory: WEIGHT_RELATIVE_TO_PRIOR_PLUGIN, ExecuteTimePluginFactory: WEIGHT_RELATIVE_TO_PRIOR_PLUGIN, - DeveloperPluginFactory: WEIGHT_RELATIVE_TO_PRIOR_PLUGIN, - FederatedAuthPluginFactory: WEIGHT_RELATIVE_TO_PRIOR_PLUGIN + DeveloperPluginFactory: WEIGHT_RELATIVE_TO_PRIOR_PLUGIN } def __init__( diff --git a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties index bde6fca8..5fbca7b5 100644 --- a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties +++ b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties @@ -22,10 +22,8 @@ AuroraInitialConnectionStrategyPlugin.UnsupportedStrategy=[AuroraInitialConnecti AdfsCredentialsProviderFactory.FailedLogin=[AdfsCredentialsProviderFactory] Failed login. Could not obtain SAML Assertion from ADFS SignOn Page POST response: '{}' AdfsCredentialsProviderFactory.GetSamlAssertionFailed=[AdfsCredentialsProviderFactory] Failed to get SAML Assertion due to exception: '{}' -AdfsCredentialsProviderFactory.InvalidHttpsUrl=[AdfsCredentialsProviderFactory] Invalid HTTPS URL: '{}' AdfsCredentialsProviderFactory.SignOnPagePostActionUrl=[AdfsCredentialsProviderFactory] ADFS SignOn Action URL: '{}' AdfsCredentialsProviderFactory.SignOnPagePostActionRequestFailed=[AdfsCredentialsProviderFactory] ADFS SignOn Page POST action failed with HTTP status '{}', reason phrase '{}', and response '{}' -AdfsCredentialsProviderFactory.SignOnPageRequestFailed=[AdfsCredentialsProviderFactory] ADFS SignOn Page Request Failed with HTTP status '{}', reason phrase '{}', and response '{}' AdfsCredentialsProviderFactory.SignOnPageUrl=[AdfsCredentialsProviderFactory] ADFS SignOn URL: '{}' AwsSdk.UnsupportedRegion=[AwsSdk] Unsupported AWS region {}. For supported regions please read https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/Concepts.RegionsAndAvailabilityZones.html @@ -102,6 +100,15 @@ FastestResponseStrategyPlugin.RandomHostSelected=[FastestResponseStrategyPlugin] FastestResponseStrategyPlugin.UnsupportedHostSelectorStrategy=[FastestResponseStrategyPlugin] Unsupported host selector strategy: '{}'. To use the fastest response strategy plugin, please ensure the property reader_host_selector_strategy is set to fastest_response. FederatedAuthPlugin.UnhandledException=[FederatedAuthPlugin] Unhandled exception: '{}' +FederatedAuthPlugin.UseCachedToken=[FederatedAuthPlugin] Used cached authentication token = {} + +OktaAuthPlugin.UnhandledException=[OktaAuthPlugin] Unhandled exception: '{}' +OktaAuthPlugin.UseCachedToken=[OktaAuthPlugin] Used cached authentication token = {} + +OktaCredentialsProviderFactory.SamlAssertionUrl=[OktaCredentialsProviderFactory] Retrieving SAML assertion from URL: '{}' + +SamlUtils.InvalidHttpsUrl=[SamlUtils] Invalid HTTPS URL: '{}' +SamlUtils.RequestFailed=[SamlUtils] Request Failed with HTTP status '{}', reason phrase '{}', and response '{}' FederatedAuthPluginFactory.UnsupportedIdp=[FederatedAuthPluginFactory] Unsupported Identity Provider '{}'. Please visit to the documentation for supported Identity Providers. @@ -123,13 +130,13 @@ HostSelector.NoEligibleHost=[HostSelector] No Eligible Hosts Found. HostSelector.NoHostsMatchingRole=[HostSelector] No hosts were found matching the requested role: '{}'. IamAuthPlugin.ConnectException=[IamAuthPlugin] Error occurred while opening a connection: {} -IamAuthPlugin.GeneratedNewIamToken=[IamAuthPlugin] Generated new IAM token = {} IamAuthPlugin.InvalidPort=[IamAuthPlugin] Port number: {} is not valid. Port number should be greater than zero. Falling back to default port. IamAuthPlugin.NoValidPort=[IamAuthPlugin] Unable to determine a valid port. IamAuthPlugin.UnhandledException=[IamAuthPlugin] Unhandled exception: {} IamAuthPlugin.UseCachedIamToken=[IamAuthPlugin] Used cached IAM token = {} -IAMAuthPlugin.InvalidHost=[IamAuthPlugin] Invalid IAM host {}. The IAM host must be a valid RDS or Aurora endpoint. -IamPlugin.IsNoneOrEmpty=[IamPlugin] Property "{}" is None or empty. +IamAuthPlugin.InvalidHost=[IamAuthPlugin] Invalid IAM host {}. The IAM host must be a valid RDS or Aurora endpoint. +IamAuthPlugin.IsNoneOrEmpty=[IamAuthPlugin] Property "{}" is None or empty. +IamAuthUtils.GeneratedNewAuthToken=Generated new authentication token = {} LogUtils.Topology=[LogUtils] Topology {} diff --git a/aws_advanced_python_wrapper/utils/iam_utils.py b/aws_advanced_python_wrapper/utils/iam_utils.py new file mode 100644 index 00000000..ba5a6924 --- /dev/null +++ b/aws_advanced_python_wrapper/utils/iam_utils.py @@ -0,0 +1,153 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from datetime import datetime +from typing import TYPE_CHECKING, Dict, Optional + +import boto3 + +from aws_advanced_python_wrapper.errors import AwsWrapperError +from aws_advanced_python_wrapper.utils.log import Logger +from aws_advanced_python_wrapper.utils.messages import Messages +from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType +from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils +from aws_advanced_python_wrapper.utils.telemetry.telemetry import \ + TelemetryTraceLevel + +if TYPE_CHECKING: + from aws_advanced_python_wrapper.hostinfo import HostInfo + from aws_advanced_python_wrapper.plugin_service import PluginService + from boto3 import Session + +from aws_advanced_python_wrapper.utils.properties import (Properties, + WrapperProperties) + +logger = Logger(__name__) + + +class IamAuthUtils: + @staticmethod + def get_iam_host(props: Properties, host_info: HostInfo): + host = WrapperProperties.IAM_HOST.get(props) if WrapperProperties.IAM_HOST.get(props) else host_info.host + IamAuthUtils.validate_iam_host(host) + return host + + @staticmethod + def validate_iam_host(host: str | None): + if host is None: + raise AwsWrapperError(Messages.get_formatted("IamAuthPlugin.InvalidHost", "[No host provided]")) + + utils = RdsUtils() + rds_type = utils.identify_rds_type(host) + if rds_type == RdsUrlType.OTHER or rds_type == RdsUrlType.IP_ADDRESS: + raise AwsWrapperError(Messages.get_formatted("IamAuthPlugin.InvalidHost", host)) + + @staticmethod + def get_port(props: Properties, host_info: HostInfo, dialect_default_port: int) -> int: + default_port: int = WrapperProperties.IAM_DEFAULT_PORT.get_int(props) + if default_port > 0: + return default_port + + if host_info.is_port_specified(): + return host_info.port + + return dialect_default_port + + @staticmethod + def get_cache_key(user: Optional[str], hostname: Optional[str], port: int, region: Optional[str]) -> str: + return f"{region}:{hostname}:{port}:{user}" + + @staticmethod + def get_rds_region(rds_utils: RdsUtils, hostname: Optional[str], props: Properties, client_session: Optional[Session] = None) -> str: + rds_region = WrapperProperties.IAM_REGION.get(props) + if rds_region is None or rds_region == "": + rds_region = rds_utils.get_rds_region(hostname) + + if not rds_region: + error_message = "RdsUtils.UnsupportedHostname" + logger.debug(error_message, hostname) + raise AwsWrapperError(Messages.get_formatted(error_message, hostname)) + + session = client_session if client_session else boto3.Session() + if rds_region not in session.get_available_regions("rds"): + error_message = "AwsSdk.UnsupportedRegion" + logger.debug(error_message, rds_region) + raise AwsWrapperError(Messages.get_formatted(error_message, rds_region)) + + return rds_region + + @staticmethod + def generate_authentication_token( + plugin_service: PluginService, + user: Optional[str], + host_name: Optional[str], + port: Optional[int], + region: Optional[str], + credentials: Optional[Dict[str, str]] = None, + client_session: Optional[Session] = None) -> str: + telemetry_factory = plugin_service.get_telemetry_factory() + context = telemetry_factory.open_telemetry_context("fetch authentication token", TelemetryTraceLevel.NESTED) + + try: + session = client_session if client_session else boto3.Session() + + if credentials is not None: + client = session.client( + 'rds', + region_name=region, + aws_access_key_id=credentials.get('AccessKeyId'), + aws_secret_access_key=credentials.get('SecretAccessKey'), + aws_session_token=credentials.get('SessionToken') + ) + else: + client = session.client( + 'rds', + region_name=region + ) + + token = client.generate_db_auth_token( + DBHostname=host_name, + Port=port, + DBUsername=user + ) + + client.close() + + logger.debug("IamAuthUtils.GeneratedNewAuthToken", token) + return token + except Exception as ex: + context.set_success(False) + context.set_exception(ex) + raise ex + finally: + context.close_context() + + +class TokenInfo: + @property + def token(self): + return self._token + + @property + def expiration(self): + return self._expiration + + def __init__(self, token: str, expiration: datetime): + self._token = token + self._expiration = expiration + + def is_expired(self) -> bool: + return datetime.now() > self._expiration diff --git a/aws_advanced_python_wrapper/utils/iamutils.py b/aws_advanced_python_wrapper/utils/iamutils.py deleted file mode 100644 index cedcd609..00000000 --- a/aws_advanced_python_wrapper/utils/iamutils.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from datetime import datetime -from typing import TYPE_CHECKING - -from aws_advanced_python_wrapper.errors import AwsWrapperError -from aws_advanced_python_wrapper.utils.messages import Messages -from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType -from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils - -if TYPE_CHECKING: - from aws_advanced_python_wrapper.hostinfo import HostInfo - -from aws_advanced_python_wrapper.utils.properties import (Properties, - WrapperProperties) - - -class IamAuthUtils: - @staticmethod - def get_iam_host(props: Properties, host_info: HostInfo): - host = WrapperProperties.IAM_HOST.get(props) if WrapperProperties.IAM_HOST.get(props) else host_info.host - IamAuthUtils.validate_iam_host(host) - return host - - @staticmethod - def validate_iam_host(host: str | None): - if host is None: - raise AwsWrapperError(Messages.get_formatted("IAMAuthPlugin.InvalidHost", "[No host provided]")) - - utils = RdsUtils() - rds_type = utils.identify_rds_type(host) - if rds_type == RdsUrlType.OTHER or rds_type == RdsUrlType.IP_ADDRESS: - raise AwsWrapperError(Messages.get_formatted("IAMAuthPlugin.InvalidHost", host)) - - @staticmethod - def get_port(props: Properties, host_info: HostInfo, dialect_default_port: int) -> int: - default_port: int = WrapperProperties.IAM_DEFAULT_PORT.get_int(props) - if default_port > 0: - return default_port - - if host_info.is_port_specified(): - return host_info.port - - return dialect_default_port - - -class TokenInfo: - @property - def token(self): - return self._token - - @property - def expiration(self): - return self._expiration - - def __init__(self, token: str, expiration: datetime): - self._token = token - self._expiration = expiration - - def is_expired(self) -> bool: - return datetime.now() > self._expiration diff --git a/aws_advanced_python_wrapper/utils/properties.py b/aws_advanced_python_wrapper/utils/properties.py index 0363f712..93825d56 100644 --- a/aws_advanced_python_wrapper/utils/properties.py +++ b/aws_advanced_python_wrapper/utils/properties.py @@ -289,6 +289,10 @@ class WrapperProperties: "The database user used to access the database", None) + # Okta + + APP_ID = WrapperProperty("app_id", "The ID of the AWS application configured on Okta", None) + # Fastest Response Strategy RESPONSE_MEASUREMENT_INTERVAL_MILLIS = WrapperProperty("response_measurement_interval_ms", "Interval in milliseconds between measuring response time to a database host", diff --git a/aws_advanced_python_wrapper/utils/saml_utils.py b/aws_advanced_python_wrapper/utils/saml_utils.py new file mode 100644 index 00000000..2070372a --- /dev/null +++ b/aws_advanced_python_wrapper/utils/saml_utils.py @@ -0,0 +1,61 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from re import search +from typing import TYPE_CHECKING +from urllib.parse import urlparse + +from aws_advanced_python_wrapper.errors import AwsWrapperError +from aws_advanced_python_wrapper.utils.log import Logger +from aws_advanced_python_wrapper.utils.messages import Messages +from aws_advanced_python_wrapper.utils.properties import (Properties, + WrapperProperties) + +if TYPE_CHECKING: + from requests import Response + + +logger = Logger(__name__) + + +class SamlUtils: + _HTTPS_URL_PATTERN = r"^(https)://[-a-zA-Z0-9+&@#/%?=~_!:,.']*[-a-zA-Z0-9+&@#/%=~_']" + + @staticmethod + def check_idp_credentials_with_fallback(props: Properties) -> None: + if WrapperProperties.IDP_USERNAME.get(props) is None: + WrapperProperties.IDP_USERNAME.set(props, WrapperProperties.USER.name) + if WrapperProperties.IDP_PASSWORD.get(props) is None: + WrapperProperties.IDP_PASSWORD.set(props, WrapperProperties.PASSWORD.name) + + @staticmethod + def validate_url(url: str) -> None: + result = urlparse(url) + if not result.scheme or not search(SamlUtils._HTTPS_URL_PATTERN, url): + error_message = "SamlUtils.InvalidHttpsUrl" + logger.debug(error_message, url) + raise AwsWrapperError(Messages.get_formatted(error_message, url)) + + @staticmethod + def validate_response(response: Response): + """Validates the HTTP response to ensure the request was successfully executed. + + :param response: HTTP response object + :raise AwsWrapperError: if response is invalid + """ + if response.status_code / 100 != 2: + error_message = "SamlUtils.RequestFailed" + logger.debug(error_message, response.status_code, response.reason, response.text) + raise AwsWrapperError(Messages.get_formatted(error_message, response.status_code, response.reason, response.text)) diff --git a/docs/examples/MySQLFederatedAuthentication.py b/docs/examples/MySQLFederatedAuthentication.py index 5d51f88e..415b0f5e 100644 --- a/docs/examples/MySQLFederatedAuthentication.py +++ b/docs/examples/MySQLFederatedAuthentication.py @@ -23,13 +23,14 @@ database="mysql", plugins="federated_auth", idp_name="adfs", + app_id="abcde1fgh3kLZTBz1S5d7", idp_endpoint="ec2amaz-ab3cdef.example.com", iam_role_arn="arn:aws:iam::123456789012:role/adfs_example_iam_role", iam_idp_arn="arn:aws:iam::123456789012:saml-provider/adfs_example", iam_region="us-east-2", idp_username="some_federated_username@example.com", idp_password="some_password", - user="john", + db_user="john", autocommit=True ) as awsconn, awsconn.cursor() as awscursor: awscursor.execute("SELECT 1") diff --git a/docs/examples/MySQLOktaAuthentication.py b/docs/examples/MySQLOktaAuthentication.py new file mode 100644 index 00000000..b4966438 --- /dev/null +++ b/docs/examples/MySQLOktaAuthentication.py @@ -0,0 +1,38 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import mysql.connector + +from aws_advanced_python_wrapper import AwsWrapperConnection + +if __name__ == "__main__": + with AwsWrapperConnection.connect( + mysql.connector.Connect, + host="database.cluster-xyz.us-east-2.rds.amazonaws.com", + database="mysql", + plugins="okta", + idp_endpoint="ec2amaz-ab3cdef.example.com", + app_id="abcde1fgh3kLZTBz1S5d7", + iam_role_arn="arn:aws:iam::123456789012:role/adfs_example_iam_role", + iam_idp_arn="arn:aws:iam::123456789012:saml-provider/adfs_example", + iam_region="us-east-2", + idp_username="some_federated_username@example.com", + idp_password="some_password", + db_user="john", + autocommit=True + ) as awsconn, awsconn.cursor() as awscursor: + awscursor.execute("SELECT @@aurora_server_id") + + res = awscursor.fetchone() + print(res) diff --git a/docs/examples/PGFederatedAuthentication.py b/docs/examples/PGFederatedAuthentication.py index 44d060db..c718e091 100644 --- a/docs/examples/PGFederatedAuthentication.py +++ b/docs/examples/PGFederatedAuthentication.py @@ -29,7 +29,7 @@ iam_region="us-east-2", idp_username="some_federated_username@example.com", idp_password="some_password", - user="john", + db_user="john", autocommit=True ) as awsconn, awsconn.cursor() as awscursor: awscursor.execute("SELECT 1") diff --git a/docs/examples/PGOktaAuthentication.py b/docs/examples/PGOktaAuthentication.py new file mode 100644 index 00000000..edfbfab5 --- /dev/null +++ b/docs/examples/PGOktaAuthentication.py @@ -0,0 +1,38 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import psycopg + +from aws_advanced_python_wrapper import AwsWrapperConnection + +if __name__ == "__main__": + with AwsWrapperConnection.connect( + psycopg.Connection.connect, + host="database.cluster-xyz.us-east-2.rds.amazonaws.com", + dbname="postgres", + plugins="okta", + idp_endpoint="ec2amaz-ab3cdef.example.com", + app_id="abcde1fgh3kLZTBz1S5d7", + iam_role_arn="arn:aws:iam::123456789012:role/adfs_example_iam_role", + iam_idp_arn="arn:aws:iam::123456789012:saml-provider/adfs_example", + iam_region="us-east-2", + idp_username="some_federated_username@example.com", + idp_password="some_password", + db_user="john", + autocommit=True + ) as awsconn, awsconn.cursor() as awscursor: + awscursor.execute("SELECT * FROM aurora_db_instance_identifier()") + + res = awscursor.fetchone() + print(res) diff --git a/docs/using-the-python-driver/using-plugins/UsingTheFederatedAuthenticationPlugin.md b/docs/using-the-python-driver/using-plugins/UsingTheFederatedAuthenticationPlugin.md index 02a129b3..abfa4343 100644 --- a/docs/using-the-python-driver/using-plugins/UsingTheFederatedAuthenticationPlugin.md +++ b/docs/using-the-python-driver/using-plugins/UsingTheFederatedAuthenticationPlugin.md @@ -12,7 +12,10 @@ In the case of AD FS, the user signs into the AD FS sign in page. This generates ## How to use the Federated Authentication Plugin with the AWS Advanced Python Driver ### Enabling the Federated Authentication Plugin -Note: AWS IAM database authentication is needed to use the Federated Authentication Plugin. This is because after the plugin acquires the authentication token (ex. SAML Assertion in the case of AD FS), the authentication token is then used to acquire an AWS IAM token. The AWS IAM token is then subsequently used to access the database. +> [!NOTE]\ +> AWS IAM database authentication is needed to use the Federated Authentication Plugin. +> This is because after the plugin acquires SAML assertion from the identity provider, the SAML Assertion is then used to acquire an AWS IAM token. +> The AWS IAM token is then subsequently used to access the database. 1. Enable AWS IAM database authentication on an existing database or create a new database with AWS IAM database authentication on the AWS RDS Console: - If needed, review the documentation about [IAM authentication for MariaDB, MySQL, and PostgreSQL](https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html). @@ -41,5 +44,5 @@ Note: AWS IAM database authentication is needed to use the Federated Authenticat | `ssl_secure` | Boolean | No | Whether the SSL session is to be secure and the server's certificates will be verified | `False` | `True` | ## Sample code -[MySQLFederatedAuthentication.py](../../MySQLFederatedAuthentication.py) -[PGFederatedAuthentication.py](../../PGFederatedAuthentication.py) +[MySQLFederatedAuthentication.py](../../examples/MySQLFederatedAuthentication.py) +[PGFederatedAuthentication.py](../../examples/PGFederatedAuthentication.py) diff --git a/docs/using-the-python-driver/using-plugins/UsingTheOktaAuthenticationPlugin.md b/docs/using-the-python-driver/using-plugins/UsingTheOktaAuthenticationPlugin.md new file mode 100644 index 00000000..4424ead1 --- /dev/null +++ b/docs/using-the-python-driver/using-plugins/UsingTheOktaAuthenticationPlugin.md @@ -0,0 +1,44 @@ +# Okta Authentication Plugin + +The Okta Authentication Plugin adds support for authentication via Federated Identity and then database access via IAM. + +## What is Federated Identity +Federated Identity allows users to use the same set of credentials to access multiple services or resources across different organizations. This works by having Identity Providers (IdP) that manage and authenticate user credentials, and Service Providers (SP) that are services or resources that can be internal, external, and/or belonging to various organizations. Multiple SPs can establish trust relationships with a single IdP. + +When a user wants access to a resource, it authenticates with the IdP. From this a security token generated and is passed to the SP then grants access to said resource. +In the case of AD FS, the user signs into the AD FS sign in page. This generates a SAML Assertion which acts as a security token. The user then passes the SAML Assertion to the SP when requesting access to resources. The SP verifies the SAML Assertion and grants access to the user. + +## How to use the Okta Authentication Plugin with the AWS Advanced Python Driver + +### Enabling the Okta Authentication Plugin +> [!NOTE]\ +> AWS IAM database authentication is needed to use the Okta Authentication Plugin. This is because after the plugin +> acquires SAML assertion from the identity provider, the SAML Assertion is then used to acquire an AWS IAM token. The AWS +> IAM token is then subsequently used to access the database. + +1. Enable AWS IAM database authentication on an existing database or create a new database with AWS IAM database authentication on the AWS RDS Console: + - If needed, review the documentation about [IAM authentication for MariaDB, MySQL, and PostgreSQL](https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html). +2. Configure Okta as the AWS identity provider. + - If needed, review the documentation about [Amazon Web Services Account Federation](https://help.okta.com/en-us/content/topics/deploymentguides/aws/aws-deployment.htm) on Okta's documentation. +3. Add the plugin code `okta` to the [`plugins`](../UsingThePythonDriver.md#connection-plugin-manager-parameters) value, or to the current [driver profile](../UsingThePythonDriver.md#connection-plugin-manager-parameters). +4. Specify parameters that are required or specific to your case. + +### Federated Authentication Plugin Parameters +| Parameter | Value | Required | Description | Default Value | Example Value | +|--------------------------------|:-------:|:--------:|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------|--------------------------------------------------------| +| `db_user` | String | Yes | The user name of the IAM user with access to your database.
If you have previously used the IAM Authentication Plugin, this would be the same IAM user.
For information on how to connect to your Aurora Database with IAM, see this [documentation](https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/UsingWithRDS.IAMDBAuth.Connecting.html). | `None` | `some_user_name` | +| `idp_username` | String | Yes | The user name for the `idp_endpoint` server. If this parameter is not specified, the plugin will fallback to using the `user` parameter. | `None` | `jimbob@example.com` | +| `idp_password` | String | Yes | The password associated with the `idp_endpoint` username. If this parameter is not specified, the plugin will fallback to using the `password` parameter. | `None` | `some_random_password` | +| `idp_endpoint` | String | Yes | The hosting URL for the service that you are using to authenticate into AWS Aurora. | `None` | `ec2amaz-ab3cdef.example.com` | +| `iam_role_arn` | String | Yes | The ARN of the IAM Role that is to be assumed to access AWS Aurora. | `None` | `arn:aws:iam::123456789012:role/adfs_example_iam_role` | +| `iam_idp_arn` | String | Yes | The ARN of the Identity Provider. | `None` | `arn:aws:iam::123456789012:saml-provider/adfs_example` | +| `iam_region` | String | Yes | The IAM region where the IAM token is generated. | `None` | `us-east-2` | +| `iam_host` | String | No | Overrides the host that is used to generate the IAM token. | `None` | `database.cluster-hash.us-east-1.rds.amazonaws.com` | +| `iam_default_port` | String | No | This property overrides the default port that is used to generate the IAM token. The default port is determined based on the underlying driver protocol. For now, there is support for PostgreSQL and MySQL. Target drivers with different protocols will require users to provide a default port. | `None` | `1234` | +| `iam_token_expiration` | Integer | No | Overrides the default IAM token cache expiration in seconds | `870` | `123` | +| `http_request_connect_timeout` | Integer | No | The timeout value in seconds to send the HTTP request data used by the FederatedAuthPlugin. | `60` | `60` | +| `ssl_secure` | Boolean | No | Whether the SSL session is to be secure and the server's certificates will be verified | `False` | `True` | + +## Sample code +[MySQLOktaAuthentication.py](../../examples/MySQLOktaAuthentication.py) +[PGOktaAuthentication.py](../../examples/PGOktaAuthentication.py) diff --git a/tests/unit/test_federated_auth_plugin.py b/tests/unit/test_federated_auth_plugin.py index 9781451e..2014680f 100644 --- a/tests/unit/test_federated_auth_plugin.py +++ b/tests/unit/test_federated_auth_plugin.py @@ -183,10 +183,10 @@ def test_connect_with_specified_iam_host_port_region(mocker, mock_credentials_provider_factory): properties: Properties = Properties() WrapperProperties.PLUGINS.set(properties, "federated_auth") - WrapperProperties.USER.set(properties, "specifiedUser") + WrapperProperties.DB_USER.set(properties, "specifiedUser") expected_host = "pg.testdb.us-west-2.rds.amazonaws.com" - expected_port: str = 5555 + expected_port: str = "5555" expected_region = "us-west-2" WrapperProperties.IAM_HOST.set(properties, expected_host) WrapperProperties.IAM_DEFAULT_PORT.set(properties, expected_port) @@ -207,8 +207,5 @@ def test_connect_with_specified_iam_host_port_region(mocker, is_initial_connection=False, connect_func=mock_func) - mock_client.generate_db_auth_token.assert_called_with( - DBHostname="pg.testdb.us-west-2.rds.amazonaws.com", - Port=5555, - DBUsername="specifiedUser" - ) + assert WrapperProperties.USER.get(properties) == "specifiedUser" + mock_dialect.set_password.assert_called_with(properties, _TEST_TOKEN) diff --git a/tests/unit/test_okta_plugin.py b/tests/unit/test_okta_plugin.py new file mode 100644 index 00000000..236bdd35 --- /dev/null +++ b/tests/unit/test_okta_plugin.py @@ -0,0 +1,208 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from datetime import datetime, timedelta +from typing import Dict +from unittest.mock import patch + +import pytest + +from aws_advanced_python_wrapper.hostinfo import HostInfo +from aws_advanced_python_wrapper.iam_plugin import TokenInfo +from aws_advanced_python_wrapper.okta_plugin import OktaAuthPlugin +from aws_advanced_python_wrapper.utils.properties import (Properties, + WrapperProperties) + +_GENERATED_TOKEN = "generated_token" +_TEST_TOKEN = "test_token" +_DEFAULT_PG_PORT = 5432 +_PG_CACHE_KEY = f"us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:{_DEFAULT_PG_PORT}:postgresqlUser" +_DB_USER = "postgresqlUser" + +_PG_HOST_INFO = HostInfo("pg.testdb.us-east-2.rds.amazonaws.com") + +_token_cache: Dict[str, TokenInfo] = {} + + +@pytest.fixture(autouse=True) +def clear_cache(): + _token_cache.clear() + + +@pytest.fixture +def mock_session(mocker): + return mocker.MagicMock() + + +@pytest.fixture +def mock_client(mocker): + return mocker.MagicMock() + + +@pytest.fixture +def mock_connection(mocker): + return mocker.MagicMock() + + +@pytest.fixture +def mock_func(mocker): + return mocker.MagicMock() + + +@pytest.fixture +def mock_plugin_service(mocker): + return mocker.MagicMock() + + +@pytest.fixture +def mock_dialect(mocker): + return mocker.MagicMock() + + +@pytest.fixture +def mock_credentials_provider_factory(mocker): + return mocker.MagicMock() + + +@pytest.fixture(autouse=True) +def mock_default_behavior(mock_session, mock_client, mock_func, mock_connection, mock_plugin_service, mock_dialect, + mock_credentials_provider_factory): + mock_session.client.return_value = mock_client + mock_client.generate_db_auth_token.return_value = _TEST_TOKEN + mock_session.get_available_regions.return_value = ['us-east-1', 'us-east-2', 'us-west-1', 'us-west-2'] + mock_func.return_value = mock_connection + mock_plugin_service.driver_dialect = mock_dialect + mock_plugin_service.database_dialect = mock_dialect + mock_dialect.default_port = _DEFAULT_PG_PORT + mock_credentials_provider_factory.get_aws_credentials.return_value = {"AccessKeyId": "test-access-key", + "SecretAccessKey": "test-secret-access", + "SessionToken": "test-session-token"} + + +@patch("aws_advanced_python_wrapper.okta_plugin.OktaAuthPlugin._token_cache", _token_cache) +def test_pg_connect_valid_token_in_cache(mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): + properties: Properties = Properties() + WrapperProperties.PLUGINS.set(properties, "okta") + WrapperProperties.DB_USER.set(properties, _DB_USER) + initial_token = TokenInfo(_TEST_TOKEN, datetime.now() + timedelta(minutes=5)) + _token_cache[_PG_CACHE_KEY] = initial_token + + target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, mock_session) + key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + str(_DEFAULT_PG_PORT) + ":postgesqlUser" + _token_cache[key] = initial_token + + target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=_PG_HOST_INFO, + props=properties, + is_initial_connection=False, + connect_func=mock_func) + + mock_client.generate_db_auth_token.assert_not_called() + + actual_token = _token_cache.get(_PG_CACHE_KEY) + assert _GENERATED_TOKEN != actual_token.token + assert _TEST_TOKEN == actual_token.token + assert actual_token.is_expired() is False + + +@patch("aws_advanced_python_wrapper.okta_plugin.OktaAuthPlugin._token_cache", _token_cache) +def test_expired_cached_token(mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect, mock_credentials_provider_factory): + test_props: Properties = Properties({"plugins": "okta", "user": "postgresqlUser", "idp_username": "user", "idp_password": "password"}) + WrapperProperties.DB_USER.set(test_props, _DB_USER) + initial_token = TokenInfo(_TEST_TOKEN, datetime.now() - timedelta(minutes=5)) + _token_cache[_PG_CACHE_KEY] = initial_token + + target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session) + + target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=_PG_HOST_INFO, + props=test_props, + is_initial_connection=False, + connect_func=mock_func) + + mock_client.generate_db_auth_token.assert_called_with( + DBHostname="pg.testdb.us-east-2.rds.amazonaws.com", + Port=5432, + DBUsername="postgresqlUser" + ) + assert WrapperProperties.USER.get(test_props) == _DB_USER + assert WrapperProperties.PASSWORD.get(test_props) == _TEST_TOKEN + + +@patch("aws_advanced_python_wrapper.okta_plugin.OktaAuthPlugin._token_cache", _token_cache) +def test_no_cached_token(mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect, mock_credentials_provider_factory): + test_props: Properties = Properties({"plugins": "okta", "user": "postgresqlUser", "idp_username": "user", "idp_password": "password"}) + WrapperProperties.DB_USER.set(test_props, _DB_USER) + + target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session) + + target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=_PG_HOST_INFO, + props=test_props, + is_initial_connection=False, + connect_func=mock_func) + + mock_client.generate_db_auth_token.assert_called_with( + DBHostname="pg.testdb.us-east-2.rds.amazonaws.com", + Port=5432, + DBUsername="postgresqlUser" + ) + assert WrapperProperties.USER.get(test_props) == _DB_USER + assert WrapperProperties.PASSWORD.get(test_props) == _TEST_TOKEN + + +@patch("aws_advanced_python_wrapper.okta_plugin.OktaAuthPlugin._token_cache", _token_cache) +def test_connect_with_specified_iam_host_port_region(mocker, + mock_plugin_service, + mock_session, + mock_func, + mock_client, + mock_dialect, + mock_credentials_provider_factory): + properties: Properties = Properties() + WrapperProperties.PLUGINS.set(properties, "okta") + WrapperProperties.DB_USER.set(properties, "specifiedUser") + + expected_host = "pg.testdb.us-west-2.rds.amazonaws.com" + expected_port: str = "5555" + expected_region = "us-west-2" + WrapperProperties.IAM_HOST.set(properties, expected_host) + WrapperProperties.IAM_DEFAULT_PORT.set(properties, expected_port) + WrapperProperties.IAM_REGION.set(properties, expected_region) + test_token_info = TokenInfo(_TEST_TOKEN, datetime.now() + timedelta(minutes=5)) + + key = "us-west-2:pg.testdb.us-west-2.rds.amazonaws.com:" + str(expected_port) + ":specifiedUser" + _token_cache[key] = test_token_info + + mock_client.generate_db_auth_token.return_value = f"{_TEST_TOKEN}:{expected_region}" + + target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session) + target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=HostInfo(expected_host), + props=properties, + is_initial_connection=False, + connect_func=mock_func) + + assert WrapperProperties.USER.get(properties) == "specifiedUser" + mock_dialect.set_password.assert_called_with(properties, _TEST_TOKEN)