diff --git a/aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py b/aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py index a2ce609b..a5110496 100644 --- a/aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py +++ b/aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py @@ -64,6 +64,19 @@ def populate_opened_connection_set(self, host_info: HostInfo, conn: Connection): self._track_connection(instance_endpoint, conn) + def invalidate_current_connection(self, host_info: HostInfo, conn: Optional[Connection]): + host: Optional[str] = host_info.as_alias() \ + if self._rds_utils.is_rds_instance(host_info.host) \ + else next(alias for alias in host_info.aliases if self._rds_utils.is_rds_instance(alias)) + + if not host: + return + + connection_set: Optional[WeakSet] = self._opened_connections.get(host) + if connection_set is not None: + self._log_connection_set(host, connection_set) + connection_set.discard(conn) + def invalidate_all_connections(self, host_info: Optional[HostInfo] = None, host: Optional[FrozenSet[str]] = None): """ Invalidates all opened connections pointing to the same host in a daemon thread. @@ -77,14 +90,10 @@ def invalidate_all_connections(self, host_info: Optional[HostInfo] = None, host: self.invalidate_all_connections(host=host_info.as_aliases()) return - instance_endpoint: Optional[str] = None if host is None: return - for instance in host: - if instance is not None and self._rds_utils.is_rds_instance(instance): - instance_endpoint = instance - break + instance_endpoint = next(instance for instance in host if self._rds_utils.is_rds_instance(instance)) if not instance_endpoint: return @@ -135,8 +144,8 @@ def log_opened_connections(self): return logger.debug("OpenedConnectionTracker.OpenedConnectionsTracked", msg) - def _log_connection_set(self, host: str, conn_set: Optional[WeakSet]): - if conn_set is None or len(conn_set) == 0: + def _log_connection_set(self, host: Optional[str], conn_set: Optional[WeakSet]): + if host is None or conn_set is None or len(conn_set) == 0: return conn = "" @@ -148,13 +157,14 @@ def _log_connection_set(self, host: str, conn_set: Optional[WeakSet]): class AuroraConnectionTrackerPlugin(Plugin): - _SUBSCRIBED_METHODS: Set[str] = {"*"} + _SUBSCRIBED_METHODS: Set[str] = {"connect", "force_connect"} _current_writer: Optional[HostInfo] = None _need_update_current_writer: bool = False + _METHOD_CLOSE = "Connection.close" @property def subscribed_methods(self) -> Set[str]: - return self._SUBSCRIBED_METHODS + return AuroraConnectionTrackerPlugin._SUBSCRIBED_METHODS.union(self._plugin_service.network_bound_methods) def __init__(self, plugin_service: PluginService, @@ -201,19 +211,20 @@ def _connect(self, host_info: HostInfo, connect_func: Callable): return conn def execute(self, target: object, method_name: str, execute_func: Callable, *args: Any, **kwargs: Any) -> Any: - if self._current_writer is None or self._need_update_current_writer: - self._current_writer = self._get_writer(self._plugin_service.hosts) - self._need_update_current_writer = False + self._remember_writer() try: - return execute_func() + results = execute_func() + if method_name == AuroraConnectionTrackerPlugin._METHOD_CLOSE and self._plugin_service.current_host_info is not None: + self._tracker.invalidate_current_connection(self._plugin_service.current_host_info, self._plugin_service.current_connection) + elif self._need_update_current_writer: + self._check_writer_changed() + return results except Exception as e: # Check that e is a FailoverError and that the writer has changed - if isinstance(e, FailoverError) and self._get_writer(self._plugin_service.hosts) != self._current_writer: - self._tracker.invalidate_all_connections(host_info=self._current_writer) - self._tracker.log_opened_connections() - self._need_update_current_writer = True + if isinstance(e, FailoverError): + self._check_writer_changed() raise e def _get_writer(self, hosts: Tuple[HostInfo, ...]) -> Optional[HostInfo]: @@ -222,6 +233,23 @@ def _get_writer(self, hosts: Tuple[HostInfo, ...]) -> Optional[HostInfo]: return host return None + def _remember_writer(self): + if self._current_writer is None or self._need_update_current_writer: + self._current_writer = self._get_writer(self._plugin_service.hosts) + self._need_update_current_writer = False + + def _check_writer_changed(self): + host_info_after_failover = self._get_writer(self._plugin_service.hosts) + + if self._current_writer is None: + self._current_writer = host_info_after_failover + self._need_update_current_writer = False + elif self._current_writer != host_info_after_failover: + self._tracker.invalidate_all_connections(self._current_writer) + self._tracker.log_opened_connections() + self._current_writer = host_info_after_failover + self._need_update_current_writer = False + class AuroraConnectionTrackerPluginFactory(PluginFactory): def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: diff --git a/aws_advanced_python_wrapper/aurora_initial_connection_strategy_plugin.py b/aws_advanced_python_wrapper/aurora_initial_connection_strategy_plugin.py new file mode 100644 index 00000000..1b73983d --- /dev/null +++ b/aws_advanced_python_wrapper/aurora_initial_connection_strategy_plugin.py @@ -0,0 +1,236 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import time +from typing import TYPE_CHECKING, Callable, Optional, Set + +from aws_advanced_python_wrapper.utils.log import Logger + +if TYPE_CHECKING: + from aws_advanced_python_wrapper.driver_dialect import DriverDialect + from aws_advanced_python_wrapper.pep249 import Connection + from aws_advanced_python_wrapper.plugin_service import PluginService + from aws_advanced_python_wrapper.host_list_provider import HostListProviderService + from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils + +from aws_advanced_python_wrapper.errors import AwsWrapperError +from aws_advanced_python_wrapper.host_availability import HostAvailability +from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole +from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory +from aws_advanced_python_wrapper.utils.messages import Messages +from aws_advanced_python_wrapper.utils.properties import (Properties, + WrapperProperties) +from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType + +logger = Logger(__name__) + + +class AuroraInitialConnectionStrategyPlugin(Plugin): + _plugin_service: PluginService + _host_list_provider_service: HostListProviderService + _rds_utils: RdsUtils + + @property + def subscribed_methods(self) -> Set[str]: + return {"init_host_provider", "connect", "force_connect"} + + def __init__(self, plugin_service: PluginService, properties: Properties): + self._plugin_service = plugin_service + + def init_host_provider(self, props: Properties, host_list_provider_service: HostListProviderService, init_host_provider_func: Callable): + self._host_list_provider_service = host_list_provider_service + if host_list_provider_service.is_static_host_list_provider(): + msg = Messages.get("AuroraInitialConnectionStrategyPlugin.RequireDynamicProvider") + logger.warning(msg) + raise AwsWrapperError(msg) + init_host_provider_func() + + def connect(self, target_driver_func: Callable, driver_dialect: DriverDialect, host_info: HostInfo, props: Properties, + is_initial_connection: bool, connect_func: Callable) -> Connection: + return self._connect_internal(host_info, props, is_initial_connection, connect_func) + + def force_connect(self, target_driver_func: Callable, driver_dialect: DriverDialect, host_info: HostInfo, props: Properties, + is_initial_connection: bool, force_connect_func: Callable) -> Connection: + return self._connect_internal(host_info, props, is_initial_connection, force_connect_func) + + def _connect_internal(self, host_info: HostInfo, props: Properties, is_initial_connection: bool, connect_func: Callable) -> Connection: + url_type: RdsUrlType = self._rds_utils.identify_rds_type(host_info.host) + if not url_type.is_rds_cluster: + return connect_func() + + if url_type == RdsUrlType.RDS_WRITER_CLUSTER: + writer_candidate_conn = self._get_verified_writer_connection(props, is_initial_connection, connect_func) + if writer_candidate_conn is None: + return connect_func() + return writer_candidate_conn + + if url_type == RdsUrlType.RDS_READER_CLUSTER: + reader_candidate_conn = self._get_verified_reader_connection(props, is_initial_connection, connect_func) + if reader_candidate_conn is None: + return connect_func() + return reader_candidate_conn + + # Continue with a normal workflow. + return connect_func() + + def _get_verified_writer_connection(self, props: Properties, is_initial_connection: bool, connect_func: Callable) -> Optional[Connection]: + retry_delay_ms: int = WrapperProperties.OPEN_CONNECTION_RETRY_TIMEOUT_MS.get_int(props) + end_time_nano = self._get_time() + retry_delay_ms * 1_000_000 + + writer_candidate_conn: Optional[Connection] + writer_candidate: Optional[HostInfo] + + while self._get_time() < end_time_nano: + writer_candidate_conn = None + writer_candidate = None + + try: + writer_candidate = self._get_writer() + if writer_candidate_conn is None or self._rds_utils.is_rds_cluster_dns(writer_candidate.host): + writer_candidate_conn = connect_func() + self._plugin_service.force_refresh_host_list(writer_candidate_conn) + writer_candidate = self._plugin_service.identify_connection(writer_candidate_conn) + + if writer_candidate is not None and writer_candidate.role != HostRole.WRITER: + # Shouldn't be here. But let's try again. + self._close_connection(writer_candidate_conn) + self._delay(retry_delay_ms) + continue + + if is_initial_connection: + self._host_list_provider_service.initial_connection_host_info = writer_candidate + + return writer_candidate_conn + + writer_candidate_conn = self._plugin_service.connect(writer_candidate, props) + + if self._plugin_service.get_host_role(writer_candidate_conn) != HostRole.WRITER: + self._plugin_service.force_refresh_host_list(writer_candidate_conn) + self._close_connection(writer_candidate_conn) + self._delay(retry_delay_ms) + continue + + if is_initial_connection: + self._host_list_provider_service.initial_connection_host_info = writer_candidate + return writer_candidate_conn + + except Exception as e: + if writer_candidate is not None: + self._plugin_service.set_availability(writer_candidate.as_aliases(), HostAvailability.UNAVAILABLE) + self._close_connection(writer_candidate_conn) + raise e + + return None + + def _get_verified_reader_connection(self, props: Properties, is_initial_connection: bool, connect_func: Callable) -> Optional[Connection]: + retry_delay_ms: int = WrapperProperties.OPEN_CONNECTION_RETRY_INTERVAL_MS.get_int(props) + end_time_nano = self._get_time() + WrapperProperties.OPEN_CONNECTION_RETRY_TIMEOUT_MS.get_int(props) * 1_000_000 + + reader_candidate_conn: Optional[Connection] + reader_candidate: Optional[HostInfo] + + while self._get_time() < end_time_nano: + reader_candidate_conn = None + reader_candidate = None + + try: + reader_candidate = self._get_reader(props) + if reader_candidate is None or self._rds_utils.is_rds_cluster_dns(reader_candidate.host): + # Reader not found, topology may be outdated + reader_candidate_conn = connect_func() + self._plugin_service.force_refresh_host_list(reader_candidate_conn) + reader_candidate = self._plugin_service.identify_connection(reader_candidate_conn) + + if reader_candidate is not None and reader_candidate.role != HostRole.READER: + if self._has_no_readers(): + # Cluster has no readers. Simulate Aurora reader cluster endpoint logic + if is_initial_connection and reader_candidate.host is not None: + self._host_list_provider_service.initial_connection_host_info = reader_candidate + return reader_candidate_conn + self._close_connection(reader_candidate_conn) + self._delay(retry_delay_ms) + continue + + if reader_candidate is not None and is_initial_connection: + self._host_list_provider_service.initial_connection_host_info = reader_candidate + return reader_candidate_conn + + reader_candidate_conn = self._plugin_service.connect(reader_candidate, props) + if self._plugin_service.get_host_role(reader_candidate_conn) != HostRole.READER: + # If the new connection resolves to a writer instance, this means the topology is outdated. + # Force refresh to update the topology. + self._plugin_service.force_refresh_host_list(reader_candidate_conn) + + if self._has_no_readers(): + # Cluster has no readers. Simulate Aurora reader cluster endpoint logic + if is_initial_connection: + self._host_list_provider_service.initial_connection_host_info = reader_candidate + return reader_candidate_conn + + self._close_connection(reader_candidate_conn) + self._delay(retry_delay_ms) + continue + + # Reader connection is valid and verified. + if is_initial_connection: + self._host_list_provider_service.initial_connection_host_info = reader_candidate + return reader_candidate_conn + + except Exception: + self._close_connection(reader_candidate_conn) + if reader_candidate is not None: + self._plugin_service.set_availability(reader_candidate.as_aliases(), HostAvailability.AVAILABLE) + + return None + + def _close_connection(self, connection: Optional[Connection]): + if connection is not None: + try: + connection.close() + except Exception: + # ignore + pass + + def _delay(self, delay_ms: int): + time.sleep(delay_ms / 1000) + + def _get_writer(self) -> Optional[HostInfo]: + return next(host for host in self._plugin_service.hosts if host.role == HostRole.WRITER) + + def _get_reader(self, props: Properties) -> Optional[HostInfo]: + strategy: Optional[str] = WrapperProperties.READER_HOST_SELECTOR_STRATEGY.get(props) + if strategy is not None and self._plugin_service.accepts_strategy(HostRole.READER, strategy): + try: + return self._plugin_service.get_host_info_by_strategy(HostRole.READER, strategy) + except Exception: + # Host isn't found + return None + + raise AwsWrapperError(Messages.get_formatted("AuroraInitialConnectionStrategyPlugin.UnsupportedStrategy", strategy)) + + def _has_no_readers(self) -> bool: + if len(self._plugin_service.hosts) == 0: + # Topology inconclusive. + return False + return next(host_info for host_info in self._plugin_service.hosts if host_info.role == HostRole.READER) is None + + def _get_time(self): + return time.perf_counter_ns() + + +class AuroraInitialConnectionStrategyPluginFactory(PluginFactory): + def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: + return AuroraInitialConnectionStrategyPlugin(plugin_service, props) diff --git a/aws_advanced_python_wrapper/driver_configuration_profiles.py b/aws_advanced_python_wrapper/driver_configuration_profiles.py deleted file mode 100644 index a5ea171a..00000000 --- a/aws_advanced_python_wrapper/driver_configuration_profiles.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from typing import TYPE_CHECKING, Dict, List - -if TYPE_CHECKING: - from aws_advanced_python_wrapper.plugin import PluginFactory - - -class DriverConfigurationProfiles: - _profiles: Dict[str, List[PluginFactory]] = {} - - @classmethod - def clear_profiles(cls): - cls._profiles.clear() - - @classmethod - def add_or_replace_profile(cls, profile_name: str, factories: List[PluginFactory]): - cls._profiles[profile_name] = factories - - @classmethod - def remove_profile(cls, profile_name: str): - cls._profiles.pop(profile_name) - - @classmethod - def contains_profile(cls, profile_name: str): - return profile_name in cls._profiles - - @classmethod - def get_plugin_factories(cls, profile_name: str): - return cls._profiles[profile_name] diff --git a/aws_advanced_python_wrapper/failover_plugin.py b/aws_advanced_python_wrapper/failover_plugin.py index f0776611..d4ff35d8 100644 --- a/aws_advanced_python_wrapper/failover_plugin.py +++ b/aws_advanced_python_wrapper/failover_plugin.py @@ -273,6 +273,10 @@ def _failover(self, failed_host: Optional[HostInfo]): :param failed_host: The host with network errors. """ + + if failed_host is not None: + self._plugin_service.set_availability(failed_host.as_aliases(), HostAvailability.AVAILABLE) + if self._failover_mode == FailoverMode.STRICT_WRITER: self._failover_writer() else: diff --git a/aws_advanced_python_wrapper/host_list_provider.py b/aws_advanced_python_wrapper/host_list_provider.py index f6b94d89..90d04fb0 100644 --- a/aws_advanced_python_wrapper/host_list_provider.py +++ b/aws_advanced_python_wrapper/host_list_provider.py @@ -127,7 +127,7 @@ def initial_connection_host_info(self) -> Optional[HostInfo]: ... @initial_connection_host_info.setter - def initial_connection_host_info(self, value: HostInfo): + def initial_connection_host_info(self, value: Optional[HostInfo]): ... def is_static_host_list_provider(self) -> bool: diff --git a/aws_advanced_python_wrapper/mysql_driver_dialect.py b/aws_advanced_python_wrapper/mysql_driver_dialect.py index 031ab013..0335df4d 100644 --- a/aws_advanced_python_wrapper/mysql_driver_dialect.py +++ b/aws_advanced_python_wrapper/mysql_driver_dialect.py @@ -63,7 +63,9 @@ class MySQLDriverDialect(DriverDialect): } def is_dialect(self, connect_func: Callable) -> bool: - return MySQLDriverDialect.TARGET_DRIVER_CODE in str(signature(connect_func)) + if MySQLDriverDialect.TARGET_DRIVER_CODE not in str(signature(connect_func)): + return MySQLDriverDialect.TARGET_DRIVER_CODE.lower() in (connect_func.__module__ + connect_func.__qualname__).lower() + return True def is_closed(self, conn: Connection) -> bool: if isinstance(conn, CMySQLConnection) or isinstance(conn, MySQLConnection): diff --git a/aws_advanced_python_wrapper/pg_driver_dialect.py b/aws_advanced_python_wrapper/pg_driver_dialect.py index 921e59c9..51333a75 100644 --- a/aws_advanced_python_wrapper/pg_driver_dialect.py +++ b/aws_advanced_python_wrapper/pg_driver_dialect.py @@ -58,7 +58,9 @@ class PgDriverDialect(DriverDialect): } def is_dialect(self, connect_func: Callable) -> bool: - return PgDriverDialect.TARGET_DRIVER_CODE in str(signature(connect_func)) + if PgDriverDialect.TARGET_DRIVER_CODE not in str(signature(connect_func)): + return PgDriverDialect.TARGET_DRIVER_CODE.lower() in (connect_func.__module__ + connect_func.__qualname__).lower() + return True def is_closed(self, conn: Connection) -> bool: if isinstance(conn, psycopg.Connection): diff --git a/aws_advanced_python_wrapper/plugin_service.py b/aws_advanced_python_wrapper/plugin_service.py index f49ebf2e..e06a12e0 100644 --- a/aws_advanced_python_wrapper/plugin_service.py +++ b/aws_advanced_python_wrapper/plugin_service.py @@ -26,6 +26,7 @@ from aws_advanced_python_wrapper.driver_dialect_manager import DriverDialectManager from aws_advanced_python_wrapper.pep249 import Connection from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory + from aws_advanced_python_wrapper.profiles.configuration_profile import ConfigurationProfile from threading import Event from abc import abstractmethod @@ -47,8 +48,6 @@ UnknownDatabaseDialect) from aws_advanced_python_wrapper.default_plugin import DefaultPlugin from aws_advanced_python_wrapper.developer_plugin import DeveloperPluginFactory -from aws_advanced_python_wrapper.driver_configuration_profiles import \ - DriverConfigurationProfiles from aws_advanced_python_wrapper.errors import (AwsWrapperError, QueryTimeoutError, UnsupportedOperationError) @@ -251,7 +250,6 @@ def get_telemetry_factory(self) -> TelemetryFactory: class PluginServiceImpl(PluginService, HostListProviderService, CanReleaseResources): - _host_availability_expiring_cache: CacheMap[str, HostAvailability] = CacheMap() _executor: ClassVar[Executor] = ThreadPoolExecutor(thread_name_prefix="PluginServiceImplExecutor") @@ -262,10 +260,12 @@ def __init__( props: Properties, target_func: Callable, driver_dialect_manager: DriverDialectManager, - driver_dialect: DriverDialect): + driver_dialect: DriverDialect, + profile: Optional[ConfigurationProfile] = None): self._container = container self._container.plugin_service = self self._props = props + self._configuration_profile = profile self._original_url = PropertiesUtils.get_url(props) self._host_list_provider: HostListProvider = ConnectionStringHostListProvider(self, props) @@ -279,7 +279,9 @@ def __init__( self._target_func = target_func self._driver_dialect_manager = driver_dialect_manager self._driver_dialect = driver_dialect - self._database_dialect = self._dialect_provider.get_dialect(driver_dialect.dialect_code, props) + self._database_dialect = self._configuration_profile.database_dialect \ + if self._configuration_profile is not None and self._configuration_profile.database_dialect is not None \ + else self._dialect_provider.get_dialect(driver_dialect.dialect_code, props) @property def hosts(self) -> Tuple[HostInfo, ...]: @@ -601,14 +603,18 @@ class PluginManager(CanReleaseResources): FederatedAuthPluginFactory: WEIGHT_RELATIVE_TO_PRIOR_PLUGIN } - def __init__( - self, container: PluginServiceManagerContainer, props: Properties, telemetry_factory: TelemetryFactory): + def __init__(self, + container: PluginServiceManagerContainer, + props: Properties, + telemetry_factory: TelemetryFactory, + profile: Optional[ConfigurationProfile] = None): self._props: Properties = props self._function_cache: Dict[str, Callable] = {} self._container = container self._container.plugin_manager = self self._connection_provider_manager = ConnectionProviderManager() self._telemetry_factory = telemetry_factory + self._configuration_profile: Optional[ConfigurationProfile] = profile self._plugins = self.get_plugins() @property @@ -636,12 +642,10 @@ def get_plugins(self) -> List[Plugin]: plugin_factories: List[PluginFactory] = [] plugins: List[Plugin] = [] - profile_name = WrapperProperties.PROFILE_NAME.get(self._props) - if profile_name is not None: - if not DriverConfigurationProfiles.contains_profile(profile_name): - raise AwsWrapperError( - Messages.get_formatted("PluginManager.ConfigurationProfileNotFound", profile_name)) - plugin_factories = DriverConfigurationProfiles.get_plugin_factories(profile_name) + if self._configuration_profile is not None: + factories = self._configuration_profile.plugin_factories + if factories is not None: + plugin_factories = self._configuration_profile.plugin_factories else: plugin_codes = WrapperProperties.PLUGINS.get(self._props) if plugin_codes is None: diff --git a/aws_advanced_python_wrapper/profiles/__init__.py b/aws_advanced_python_wrapper/profiles/__init__.py new file mode 100644 index 00000000..bd4acb2b --- /dev/null +++ b/aws_advanced_python_wrapper/profiles/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/aws_advanced_python_wrapper/profiles/configuration_profile.py b/aws_advanced_python_wrapper/profiles/configuration_profile.py new file mode 100644 index 00000000..bc5e984c --- /dev/null +++ b/aws_advanced_python_wrapper/profiles/configuration_profile.py @@ -0,0 +1,70 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, List, Optional + +if TYPE_CHECKING: + from aws_advanced_python_wrapper.connection_provider import \ + ConnectionProvider + from aws_advanced_python_wrapper.database_dialect import DatabaseDialect + from aws_advanced_python_wrapper.driver_dialect import DriverDialect + from aws_advanced_python_wrapper.exception_handling import ExceptionHandler + from aws_advanced_python_wrapper.plugin import PluginFactory + +from aws_advanced_python_wrapper.utils.properties import Properties + + +class ConfigurationProfile: + def __init__( + self, + name: str, + properties: Properties = Properties(), + plugin_factories: List[PluginFactory] = [], + dialect: Optional[DatabaseDialect] = None, + target_driver_dialect: Optional[DriverDialect] = None, + exception_handler: Optional[ExceptionHandler] = None, + connection_provider: Optional[ConnectionProvider] = None): + self._name = name + self._plugin_factories = plugin_factories + self._properties = properties + self._database_dialect = dialect + self._target_driver_dialect = target_driver_dialect + self._exception_handler = exception_handler + self._connection_provider = connection_provider + + @property + def name(self) -> str: + return self._name + + @property + def properties(self) -> Properties: + return self._properties + + @property + def plugin_factories(self) -> List[PluginFactory]: + return self._plugin_factories + + @property + def database_dialect(self) -> Optional[DatabaseDialect]: + return self._database_dialect + + @property + def target_driver_dialect(self) -> Optional[DriverDialect]: + return self._target_driver_dialect + + @property + def connection_provider(self) -> Optional[ConnectionProvider]: + return self._connection_provider diff --git a/aws_advanced_python_wrapper/profiles/configuration_profile_preset_codes.py b/aws_advanced_python_wrapper/profiles/configuration_profile_preset_codes.py new file mode 100644 index 00000000..c376ef53 --- /dev/null +++ b/aws_advanced_python_wrapper/profiles/configuration_profile_preset_codes.py @@ -0,0 +1,35 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http:#www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +class ConfigurationProfilePresetCodes: + # Presets family A, B, C - no connection pool + # Presets family D, E ,F - internal connection pool + # Presets family G, H, I - external connection pool + + A0 = "A0" # Normal + A1 = "A1" # Easy + A2 = "A2" # Aggressive + B = "B" # Normal + PG_C0 = "PG_C0" # Normal + PG_C1 = "PG_C1" # Aggressive + D0 = "D0" # Normal + D1 = "D1" # Easy + E = "E" # Normal + PG_F0 = "PG_F0" # Normal + PG_F1 = "PG_F1" # Aggressive + G0 = "G0" # Normal + G1 = "G1" # Easy + H = "H" # Normal + PG_I0 = "PG_I0" # Normal + PG_I1 = "PG_I1" # Aggressive diff --git a/aws_advanced_python_wrapper/profiles/driver_configuration_profiles.py b/aws_advanced_python_wrapper/profiles/driver_configuration_profiles.py new file mode 100644 index 00000000..a4bcf187 --- /dev/null +++ b/aws_advanced_python_wrapper/profiles/driver_configuration_profiles.py @@ -0,0 +1,247 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Dict, Optional + +from aws_advanced_python_wrapper.aurora_connection_tracker_plugin import \ + AuroraConnectionTrackerPluginFactory +from aws_advanced_python_wrapper.aurora_initial_connection_strategy_plugin import \ + AuroraInitialConnectionStrategyPluginFactory +from aws_advanced_python_wrapper.failover_plugin import FailoverPluginFactory +from aws_advanced_python_wrapper.host_monitoring_plugin import \ + HostMonitoringPluginFactory +from aws_advanced_python_wrapper.profiles.configuration_profile import \ + ConfigurationProfile +from aws_advanced_python_wrapper.profiles.configuration_profile_preset_codes import \ + ConfigurationProfilePresetCodes +from aws_advanced_python_wrapper.read_write_splitting_plugin import \ + ReadWriteSplittingPluginFactory +from aws_advanced_python_wrapper.sql_alchemy_connection_provider import \ + SqlAlchemyPooledConnectionProvider +from aws_advanced_python_wrapper.stale_dns_plugin import StaleDnsPluginFactory +from aws_advanced_python_wrapper.utils.properties import (Properties, + WrapperProperties) + + +class DriverConfigurationProfiles: + _profiles: Dict[str, Optional[ConfigurationProfile]] = {} + _presets: Dict[str, ConfigurationProfile] = { + ConfigurationProfilePresetCodes.A0: ConfigurationProfile( + name=ConfigurationProfilePresetCodes.A0, + properties=Properties({WrapperProperties.CONNECT_TIMEOUT_SEC.name: 10, + WrapperProperties.SOCKET_TIMEOUT_SEC.name: 5, + WrapperProperties.TCP_KEEPALIVE.name: False, + "autocommit": True}) + ), + ConfigurationProfilePresetCodes.A1: ConfigurationProfile( + name=ConfigurationProfilePresetCodes.A1, + properties=Properties({WrapperProperties.CONNECT_TIMEOUT_SEC.name: 30, + WrapperProperties.SOCKET_TIMEOUT_SEC.name: 3, + WrapperProperties.TCP_KEEPALIVE.name: False, + "autocommit": True}) + ), + ConfigurationProfilePresetCodes.A2: ConfigurationProfile( + name=ConfigurationProfilePresetCodes.A2, + properties=Properties({WrapperProperties.CONNECT_TIMEOUT_SEC.name: 3, + WrapperProperties.SOCKET_TIMEOUT_SEC.name: 3, + WrapperProperties.TCP_KEEPALIVE.name: False, + "autocommit": True}) + ), + ConfigurationProfilePresetCodes.B: ConfigurationProfile( + name=ConfigurationProfilePresetCodes.B, + properties=Properties({WrapperProperties.CONNECT_TIMEOUT_SEC.name: 10, + WrapperProperties.SOCKET_TIMEOUT_SEC.name: 0, + WrapperProperties.TCP_KEEPALIVE.name: True, + "autocommit": True}) + ), + ConfigurationProfilePresetCodes.PG_C0: ConfigurationProfile( + name=ConfigurationProfilePresetCodes.PG_C0, + plugin_factories=[HostMonitoringPluginFactory()], + properties=Properties({WrapperProperties.CONNECT_TIMEOUT_SEC.name: 10, + WrapperProperties.SOCKET_TIMEOUT_SEC.name: 0, + WrapperProperties.FAILURE_DETECTION_COUNT.name: 5, + WrapperProperties.FAILURE_DETECTION_TIME_MS.name: 60000, + WrapperProperties.FAILURE_DETECTION_INTERVAL_MS.name: 15000, + WrapperProperties.TCP_KEEPALIVE.name: False, + "autocommit": True}) + ), + ConfigurationProfilePresetCodes.PG_C1: ConfigurationProfile( + name=ConfigurationProfilePresetCodes.PG_C1, + plugin_factories=[HostMonitoringPluginFactory()], + properties=Properties({WrapperProperties.CONNECT_TIMEOUT_SEC.name: 10, + WrapperProperties.SOCKET_TIMEOUT_SEC.name: 0, + WrapperProperties.FAILURE_DETECTION_COUNT.name: 5, + WrapperProperties.FAILURE_DETECTION_TIME_MS.name: 30000, + WrapperProperties.FAILURE_DETECTION_INTERVAL_MS.name: 5000, + "monitoring-" + WrapperProperties.CONNECT_TIMEOUT_SEC.name: 3, + "monitoring-" + WrapperProperties.SOCKET_TIMEOUT_SEC.name: 3, + WrapperProperties.TCP_KEEPALIVE.name: False, + "autocommit": True}) + ), + ConfigurationProfilePresetCodes.D0: ConfigurationProfile( + name=ConfigurationProfilePresetCodes.D0, + plugin_factories=[AuroraInitialConnectionStrategyPluginFactory(), AuroraConnectionTrackerPluginFactory(), + ReadWriteSplittingPluginFactory(), FailoverPluginFactory()], + properties=Properties({WrapperProperties.CONNECT_TIMEOUT_SEC.name: 10, + WrapperProperties.SOCKET_TIMEOUT_SEC.name: 5, + WrapperProperties.TCP_KEEPALIVE.name: False, + "autocommit": True}), + connection_provider=SqlAlchemyPooledConnectionProvider(lambda _, __: {"pool_size": 30, + "pool_recycle": 86400, + "pool_timeout": 180}) + ), + ConfigurationProfilePresetCodes.D1: ConfigurationProfile( + name=ConfigurationProfilePresetCodes.D1, + plugin_factories=[AuroraInitialConnectionStrategyPluginFactory(), AuroraConnectionTrackerPluginFactory(), + ReadWriteSplittingPluginFactory(), FailoverPluginFactory()], + properties=Properties({WrapperProperties.CONNECT_TIMEOUT_SEC.name: 30, + WrapperProperties.SOCKET_TIMEOUT_SEC.name: 30, + WrapperProperties.TCP_KEEPALIVE.name: False, + "autocommit": True}), + connection_provider=SqlAlchemyPooledConnectionProvider(lambda _, __: {"pool_size": 30, + "pool_recycle": 86400, + "pool_timeout": 10}) + ), + ConfigurationProfilePresetCodes.E: ConfigurationProfile( + name=ConfigurationProfilePresetCodes.E, + plugin_factories=[AuroraInitialConnectionStrategyPluginFactory(), AuroraConnectionTrackerPluginFactory(), + ReadWriteSplittingPluginFactory(), FailoverPluginFactory()], + properties=Properties({WrapperProperties.CONNECT_TIMEOUT_SEC.name: 10, + WrapperProperties.SOCKET_TIMEOUT_SEC.name: 0, + WrapperProperties.TCP_KEEPALIVE.name: False, + "autocommit": True}), + connection_provider=SqlAlchemyPooledConnectionProvider(lambda _, __: {"pool_size": 30, + "pool_recycle": 86400, + "pool_timeout": 10}) + ), + ConfigurationProfilePresetCodes.PG_F0: ConfigurationProfile( + name=ConfigurationProfilePresetCodes.PG_F0, + plugin_factories=[AuroraInitialConnectionStrategyPluginFactory(), AuroraConnectionTrackerPluginFactory(), + ReadWriteSplittingPluginFactory(), FailoverPluginFactory(), HostMonitoringPluginFactory()], + properties=Properties({WrapperProperties.CONNECT_TIMEOUT_SEC.name: 10, + WrapperProperties.SOCKET_TIMEOUT_SEC.name: 0, + WrapperProperties.FAILURE_DETECTION_COUNT.name: 5, + WrapperProperties.FAILURE_DETECTION_TIME_MS.name: 60000, + WrapperProperties.FAILURE_DETECTION_INTERVAL_MS.name: 15000, + "monitoring-" + WrapperProperties.CONNECT_TIMEOUT_SEC.name: 10, + "monitoring-" + WrapperProperties.SOCKET_TIMEOUT_SEC.name: 5, + WrapperProperties.TCP_KEEPALIVE.name: False, + "autocommit": True}), + connection_provider=SqlAlchemyPooledConnectionProvider(lambda _, __: {"pool_size": 30, + "pool_recycle": 86400, + "pool_timeout": 10}) + ), + ConfigurationProfilePresetCodes.PG_F1: ConfigurationProfile( + name=ConfigurationProfilePresetCodes.PG_F1, + plugin_factories=[AuroraInitialConnectionStrategyPluginFactory(), AuroraConnectionTrackerPluginFactory(), + ReadWriteSplittingPluginFactory(), FailoverPluginFactory(), HostMonitoringPluginFactory()], + properties=Properties({WrapperProperties.CONNECT_TIMEOUT_SEC.name: 10, + WrapperProperties.SOCKET_TIMEOUT_SEC.name: 0, + WrapperProperties.FAILURE_DETECTION_COUNT.name: 5, + WrapperProperties.FAILURE_DETECTION_TIME_MS.name: 30000, + WrapperProperties.FAILURE_DETECTION_INTERVAL_MS.name: 5000, + "monitoring-" + WrapperProperties.CONNECT_TIMEOUT_SEC.name: 3, + "monitoring-" + WrapperProperties.SOCKET_TIMEOUT_SEC.name: 3, + WrapperProperties.TCP_KEEPALIVE.name: False, + "autocommit": True}), + connection_provider=SqlAlchemyPooledConnectionProvider(lambda _, __: {"pool_size": 30, + "pool_recycle": 86400, + "pool_timeout": 10}) + ), + ConfigurationProfilePresetCodes.G0: ConfigurationProfile( + name=ConfigurationProfilePresetCodes.G0, + plugin_factories=[AuroraConnectionTrackerPluginFactory(), StaleDnsPluginFactory(), FailoverPluginFactory()], + properties=Properties({WrapperProperties.CONNECT_TIMEOUT_SEC.name: 10, + WrapperProperties.SOCKET_TIMEOUT_SEC.name: 5, + WrapperProperties.TCP_KEEPALIVE.name: False, + "autocommit": True}) + ), + ConfigurationProfilePresetCodes.G1: ConfigurationProfile( + name=ConfigurationProfilePresetCodes.G1, + plugin_factories=[AuroraConnectionTrackerPluginFactory(), StaleDnsPluginFactory(), FailoverPluginFactory()], + properties=Properties({WrapperProperties.CONNECT_TIMEOUT_SEC.name: 30, + WrapperProperties.SOCKET_TIMEOUT_SEC.name: 30, + WrapperProperties.TCP_KEEPALIVE.name: False, + "autocommit": True}) + ), + ConfigurationProfilePresetCodes.H: ConfigurationProfile( + name=ConfigurationProfilePresetCodes.G1, + plugin_factories=[AuroraConnectionTrackerPluginFactory(), StaleDnsPluginFactory(), FailoverPluginFactory()], + properties=Properties({WrapperProperties.CONNECT_TIMEOUT_SEC.name: 10, + WrapperProperties.SOCKET_TIMEOUT_SEC.name: 0, + WrapperProperties.TCP_KEEPALIVE.name: True, + "autocommit": True}) + ), + ConfigurationProfilePresetCodes.PG_I0: ConfigurationProfile( + name=ConfigurationProfilePresetCodes.PG_I0, + plugin_factories=[AuroraConnectionTrackerPluginFactory(), + StaleDnsPluginFactory(), + FailoverPluginFactory(), + HostMonitoringPluginFactory()], + properties=Properties({WrapperProperties.CONNECT_TIMEOUT_SEC.name: 10, + WrapperProperties.SOCKET_TIMEOUT_SEC.name: 0, + WrapperProperties.FAILURE_DETECTION_COUNT.name: 5, + WrapperProperties.FAILURE_DETECTION_TIME_MS.name: 60000, + WrapperProperties.FAILURE_DETECTION_INTERVAL_MS.name: 15000, + "monitoring-" + WrapperProperties.CONNECT_TIMEOUT_SEC.name: 10, + "monitoring-" + WrapperProperties.SOCKET_TIMEOUT_SEC.name: 3, + WrapperProperties.TCP_KEEPALIVE.name: False, + "autocommit": True}) + ), + ConfigurationProfilePresetCodes.PG_I1: ConfigurationProfile( + name=ConfigurationProfilePresetCodes.PG_I1, + plugin_factories=[AuroraConnectionTrackerPluginFactory(), + StaleDnsPluginFactory(), + FailoverPluginFactory(), + HostMonitoringPluginFactory()], + properties=Properties({WrapperProperties.CONNECT_TIMEOUT_SEC.name: 10, + WrapperProperties.SOCKET_TIMEOUT_SEC.name: 0, + WrapperProperties.FAILURE_DETECTION_COUNT.name: 3, + WrapperProperties.FAILURE_DETECTION_TIME_MS.name: 30000, + WrapperProperties.FAILURE_DETECTION_INTERVAL_MS.name: 5000, + "monitoring-" + WrapperProperties.CONNECT_TIMEOUT_SEC.name: 3, + "monitoring-" + WrapperProperties.SOCKET_TIMEOUT_SEC.name: 3, + WrapperProperties.TCP_KEEPALIVE.name: False, + "autocommit": True}) + ) + } + + @classmethod + def clear_profiles(cls): + cls._profiles.clear() + + @classmethod + def add_or_replace_profile(cls, profile_name: str, profile: Optional[ConfigurationProfile]): + cls._profiles[profile_name] = profile + + @classmethod + def remove_profile(cls, profile_name: str): + cls._profiles.pop(profile_name) + + @classmethod + def contains_profile(cls, profile_name: str): + return profile_name in cls._profiles + + @classmethod + def get_plugin_factories(cls, profile_name: str): + return cls._profiles[profile_name] + + @classmethod + def get_profile_configuration(cls, profile_name: str): + profile: Optional[ConfigurationProfile] = cls._profiles.get(profile_name) + if profile is not None: + return profile + return cls._presets.get(profile_name) diff --git a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties index 2f573bbb..041a8731 100644 --- a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties +++ b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties @@ -25,6 +25,9 @@ AdfsCredentialsProviderFactory.SignOnPagePostActionRequestFailed=[AdfsCredential AdfsCredentialsProviderFactory.SignOnPageRequestFailed=[AdfsCredentialsProviderFactory] ADFS SignOn Page Request Failed with HTTP status '{}', reason phrase '{}', and response '{}' AdfsCredentialsProviderFactory.SignOnPageUrl=[AdfsCredentialsProviderFactory] ADFS SignOn URL: '{}' +AuroraInitialConnectionStrategyPlugin.UnsupportedStrategy=Unsupported host selection strategy '{}'. +AuroraInitialConnectionStrategyPlugin.RequireDynamicProvider=Dynamic host list provider is required. + AwsSdk.UnsupportedRegion=[AwsSdk] Unsupported AWS region {}. For supported regions please read https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/Concepts.RegionsAndAvailabilityZones.html AwsSecretsManagerPlugin.ConnectException=[AwsSecretsManagerPlugin] Error occurred while opening a connection: {} @@ -170,7 +173,6 @@ OpenTelemetryFactory.WrongParameterType="[OpenTelemetryFactory] Wrong parameter Plugin.UnsupportedMethod=[Plugin] '{}' is not supported by this plugin. -PluginManager.ConfigurationProfileNotFound=PluginManager] Configuration profile '{}' not found. PluginManager.InvalidPlugin=[PluginManager] Invalid plugin requested: '{}'. PluginManager.MethodInvokedAgainstOldConnection = [PluginManager] The internal connection has changed since '{}' was created. This is likely due to failover or read-write splitting functionality. To ensure you are using the updated connection, please re-create Cursor objects after failover and/or setting readonly. PluginManager.PipelineNone=[PluginManager] A pipeline was requested but the created pipeline evaluated to None. @@ -287,6 +289,8 @@ Wrapper.ConnectMethod=[Wrapper] Target driver should be a target driver's connec Wrapper.RequiredTargetDriver=[Wrapper] Target driver is required. Wrapper.UnsupportedAttribute=[Wrapper] Target driver does not have the attribute: '{}' Wrapper.Properties=[Wrapper] "Connection Properties: " +Wrapper.ConfigurationProfileNotFound=[Wrapper] Configuration profile '{}' not found. + WriterFailoverHandler.AlreadyWriter=[WriterFailoverHandler] Current reader connection is actually a new writer connection. WriterFailoverHandler.CurrentTopologyNone=[WriterFailoverHandler] Current topology cannot be None. diff --git a/aws_advanced_python_wrapper/utils/properties.py b/aws_advanced_python_wrapper/utils/properties.py index daedd65f..6f1d1648 100644 --- a/aws_advanced_python_wrapper/utils/properties.py +++ b/aws_advanced_python_wrapper/utils/properties.py @@ -319,6 +319,16 @@ class WrapperProperties: False ) + # Aurora Initial Connection Strategy Plugin + READER_INITIAL_HOST_SELECTOR_STRATEGY = WrapperProperty("reader_initial_connection_host_selector_strategy", + "The strategy that should be used to select a " + "new reader host while opening a new connection.", + "random") + + OPEN_CONNECTION_RETRY_TIMEOUT_MS = WrapperProperty("open_connection_retry_timeout_ms", + "Maximum allowed time in milliseconds for the retries opening a connection.", 30_000) + OPEN_CONNECTION_RETRY_INTERVAL_MS = WrapperProperty("open_connection_retry_interval_ms", "Time between each retry of opening a connection.", 1000) + class PropertiesUtils: @staticmethod diff --git a/aws_advanced_python_wrapper/wrapper.py b/aws_advanced_python_wrapper/wrapper.py index c6318a02..bb53d165 100644 --- a/aws_advanced_python_wrapper/wrapper.py +++ b/aws_advanced_python_wrapper/wrapper.py @@ -17,6 +17,12 @@ from typing import (TYPE_CHECKING, Any, Callable, Iterator, List, Optional, Union) +if TYPE_CHECKING: + from aws_advanced_python_wrapper.profiles.configuration_profile import ConfigurationProfile + +from aws_advanced_python_wrapper.profiles.driver_configuration_profiles import \ + DriverConfigurationProfiles + if TYPE_CHECKING: from aws_advanced_python_wrapper.host_list_provider import HostListProviderService @@ -32,7 +38,8 @@ from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.properties import (Properties, - PropertiesUtils) + PropertiesUtils, + WrapperProperties) from aws_advanced_python_wrapper.utils.telemetry.default_telemetry_factory import \ DefaultTelemetryFactory from aws_advanced_python_wrapper.utils.telemetry.telemetry import \ @@ -137,10 +144,19 @@ def connect( try: driver_dialect_manager: DriverDialectManager = DriverDialectManager() driver_dialect = driver_dialect_manager.get_dialect(target_func, props) + + profile_name: Optional[str] = WrapperProperties.PROFILE_NAME.get(props) + configuration_profile: Optional[ConfigurationProfile] = None + if profile_name: + configuration_profile = DriverConfigurationProfiles.get_profile_configuration(profile_name) + if configuration_profile is None: + raise AwsWrapperError(Messages.get_formatted("Wrapper.ConfigurationProfileNotFound", configuration_profile)) + props = Properties({**props, **configuration_profile.properties}) + container: PluginServiceManagerContainer = PluginServiceManagerContainer() plugin_service = PluginServiceImpl( - container, props, target_func, driver_dialect_manager, driver_dialect) - plugin_manager: PluginManager = PluginManager(container, props, telemetry_factory) + container, props, target_func, driver_dialect_manager, driver_dialect, configuration_profile) + plugin_manager: PluginManager = PluginManager(container, props, telemetry_factory, configuration_profile) return AwsWrapperConnection(target_func, plugin_service, plugin_service, plugin_manager) except Exception as ex: diff --git a/aws_advanced_python_wrapper/writer_failover_handler.py b/aws_advanced_python_wrapper/writer_failover_handler.py index 600c24d2..fc0a22d2 100644 --- a/aws_advanced_python_wrapper/writer_failover_handler.py +++ b/aws_advanced_python_wrapper/writer_failover_handler.py @@ -116,8 +116,6 @@ def get_writer(self, topology: Tuple[HostInfo, ...]) -> Optional[HostInfo]: def get_result_from_future(self, current_topology: Tuple[HostInfo, ...]) -> WriterFailoverResult: writer_host: Optional[HostInfo] = self.get_writer(current_topology) if writer_host is not None: - self._plugin_service.set_availability(writer_host.as_aliases(), HostAvailability.UNAVAILABLE) - with ThreadPoolExecutor(thread_name_prefix="WriterFailoverHandlerExecutor") as executor: try: futures = [executor.submit(self.reconnect_to_writer, writer_host), diff --git a/benchmarks/plugin_benchmarks.py b/benchmarks/plugin_benchmarks.py index fb8a0c71..acaf8d7b 100644 --- a/benchmarks/plugin_benchmarks.py +++ b/benchmarks/plugin_benchmarks.py @@ -66,6 +66,11 @@ def driver_dialect_mock(mocker): return dialect +@pytest.fixture +def configuration_profile_mock(mocker): + return mocker.MagicMock() + + @pytest.fixture def plugin_service_mock(mocker, driver_dialect_mock): service_mock = mocker.MagicMock() @@ -82,44 +87,54 @@ def plugin_service_manager_container_mock(mocker, plugin_service_mock): @pytest.fixture -def plugin_manager_with_execute_time_plugin(plugin_service_manager_container_mock, props_with_execute_time_plugin): - manager: PluginManager = PluginManager(plugin_service_manager_container_mock, props_with_execute_time_plugin) +def plugin_manager_with_execute_time_plugin( + plugin_service_manager_container_mock, + props_with_execute_time_plugin, + configuration_profile_mock): + manager: PluginManager = PluginManager( + plugin_service_manager_container_mock, + props_with_execute_time_plugin, + configuration_profile_mock) return manager @pytest.fixture def plugin_manager_with_aurora_connection_tracker_plugin( - plugin_service_manager_container_mock, props_with_aurora_connection_tracker_plugin): - + plugin_service_manager_container_mock, props_with_aurora_connection_tracker_plugin, configuration_profile_mock): manager: PluginManager = PluginManager( - plugin_service_manager_container_mock, props_with_aurora_connection_tracker_plugin) + plugin_service_manager_container_mock, props_with_aurora_connection_tracker_plugin, configuration_profile_mock) return manager @pytest.fixture def plugin_manager_with_execute_time_and_aurora_connection_tracker_plugin( - plugin_service_manager_container_mock, props_with_execute_time_and_aurora_connection_tracker_plugin): - + plugin_service_manager_container_mock, + props_with_execute_time_and_aurora_connection_tracker_plugin, + configuration_profile_mock): manager: PluginManager = PluginManager( - plugin_service_manager_container_mock, props_with_execute_time_and_aurora_connection_tracker_plugin) + plugin_service_manager_container_mock, + props_with_execute_time_and_aurora_connection_tracker_plugin, + configuration_profile_mock) return manager @pytest.fixture def plugin_manager_with_read_write_splitting_plugin( - plugin_service_manager_container_mock, props_with_read_write_splitting_plugin): - + plugin_service_manager_container_mock, props_with_read_write_splitting_plugin, configuration_profile_mock): manager: PluginManager = PluginManager( - plugin_service_manager_container_mock, props_with_read_write_splitting_plugin) + plugin_service_manager_container_mock, props_with_read_write_splitting_plugin, configuration_profile_mock) return manager @pytest.fixture def plugin_manager_with_aurora_connection_tracker_and_read_write_splitting_plugin( - plugin_service_manager_container_mock, props_with_aurora_connection_tracker_and_read_write_splitting_plugin): - + plugin_service_manager_container_mock, + props_with_aurora_connection_tracker_and_read_write_splitting_plugin, + configuration_profile_mock): manager: PluginManager = PluginManager( - plugin_service_manager_container_mock, props_with_aurora_connection_tracker_and_read_write_splitting_plugin) + plugin_service_manager_container_mock, + props_with_aurora_connection_tracker_and_read_write_splitting_plugin, + configuration_profile_mock) return manager @@ -131,21 +146,18 @@ def init_and_release(mocker, plugin_service_mock, plugin_manager): def test_init_and_release_with_execution_time_plugin( benchmark, mocker, plugin_service_mock, plugin_manager_with_execute_time_plugin): - result = benchmark(init_and_release, mocker, plugin_service_mock, plugin_manager_with_execute_time_plugin) assert result is not None def test_init_and_release_with_aurora_connection_tracker_plugin( benchmark, mocker, plugin_service_mock, plugin_manager_with_aurora_connection_tracker_plugin): - result = benchmark(init_and_release, mocker, plugin_service_mock, plugin_manager_with_aurora_connection_tracker_plugin) assert result is not None def test_init_and_release_with_execute_time_and_aurora_connection_tracker_plugin( benchmark, mocker, plugin_service_mock, plugin_manager_with_execute_time_and_aurora_connection_tracker_plugin): - result = benchmark( init_and_release, mocker, plugin_service_mock, plugin_manager_with_execute_time_and_aurora_connection_tracker_plugin) assert result is not None @@ -153,14 +165,12 @@ def test_init_and_release_with_execute_time_and_aurora_connection_tracker_plugin def test_init_and_release_with_read_write_splitting_plugin( benchmark, mocker, plugin_service_mock, plugin_manager_with_read_write_splitting_plugin): - result = benchmark(init_and_release, mocker, plugin_service_mock, plugin_manager_with_read_write_splitting_plugin) assert result is not None def test_init_and_release_with_aurora_connection_tracker_and_read_write_splitting_plugin( benchmark, mocker, plugin_service_mock, plugin_manager_with_aurora_connection_tracker_and_read_write_splitting_plugin): - result = benchmark( init_and_release, mocker, @@ -184,7 +194,6 @@ def init_and_release_internal_connection_pools(mocker, plugin_service_mock, plug def test_init_and_release_with_read_write_splitting_plugin_internal_connection_pools( benchmark, mocker, plugin_service_mock, plugin_manager_with_read_write_splitting_plugin): - result = benchmark( init_and_release_internal_connection_pools, mocker, @@ -195,7 +204,6 @@ def test_init_and_release_with_read_write_splitting_plugin_internal_connection_p def test_init_and_release_with_aurora_connection_tracker_and_read_write_splitting_plugin_internal_connection_pools( benchmark, mocker, plugin_service_mock, plugin_manager_with_aurora_connection_tracker_and_read_write_splitting_plugin): - result = benchmark( init_and_release_internal_connection_pools, mocker, @@ -224,6 +232,5 @@ def execute_query(mocker, plugin_service, plugin_manager): def test_execute_query_with_execute_time_plugin( benchmark, mocker, plugin_service_mock, plugin_manager_with_execute_time_plugin): - result = benchmark(execute_query, mocker, plugin_service_mock, plugin_manager_with_execute_time_plugin) assert result is not None diff --git a/benchmarks/plugin_manager_benchmarks.py b/benchmarks/plugin_manager_benchmarks.py index 0ed9ae3a..65c17534 100644 --- a/benchmarks/plugin_manager_benchmarks.py +++ b/benchmarks/plugin_manager_benchmarks.py @@ -18,15 +18,18 @@ import pytest +from aws_advanced_python_wrapper.profiles.configuration_profile import \ + ConfigurationProfile + if TYPE_CHECKING: from aws_advanced_python_wrapper.driver_dialect import DriverDialect from aws_advanced_python_wrapper.plugin import PluginFactory -from aws_advanced_python_wrapper.driver_configuration_profiles import \ - DriverConfigurationProfiles from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.plugin_service import ( PluginManager, PluginServiceManagerContainer) +from aws_advanced_python_wrapper.profiles.driver_configuration_profiles import \ + DriverConfigurationProfiles from aws_advanced_python_wrapper.utils.properties import Properties from benchmarks.benchmark_plugin import BenchmarkPluginFactory @@ -43,7 +46,9 @@ def props_with_plugins(): factories: List[PluginFactory] = [] for _ in range(10): factories.append(BenchmarkPluginFactory()) - DriverConfigurationProfiles.add_or_replace_profile("benchmark", factories) + DriverConfigurationProfiles.add_or_replace_profile( + "benchmark", + ConfigurationProfile(name="benchmark", plugin_factories=factories)) return Properties({"profile_name": "benchmark"}) diff --git a/tests/unit/test_aurora_connection_tracker.py b/tests/unit/test_aurora_connection_tracker.py index c0847f43..4d6b4b6d 100644 --- a/tests/unit/test_aurora_connection_tracker.py +++ b/tests/unit/test_aurora_connection_tracker.py @@ -77,7 +77,7 @@ def props(): def test_track_new_instance_connection( - mocker, mock_plugin_service, mock_rds_utils, mock_tracker, mock_cursor, mock_callable): + mocker, mock_plugin_service, mock_rds_utils, mock_tracker, mock_cursor, mock_callable, mock_conn): host_info: HostInfo = HostInfo("instance1") mock_plugin_service.hosts = [host_info] mock_plugin_service.current_host_info = host_info @@ -116,4 +116,4 @@ def test_invalidate_opened_connections( plugin.execute(mock_cursor, "Cursor.execute", mock_callable, ("select 1", {})) mock_tracker.invalidate_current_connection.assert_not_called() - mock_tracker.invalidate_all_connections.assert_called_with(host_info=original_host) + mock_tracker.invalidate_all_connections.assert_called_with(original_host) diff --git a/tests/unit/test_plugin_manager.py b/tests/unit/test_plugin_manager.py index 9e380a65..9f25cad9 100644 --- a/tests/unit/test_plugin_manager.py +++ b/tests/unit/test_plugin_manager.py @@ -16,6 +16,8 @@ from typing import TYPE_CHECKING +from aws_advanced_python_wrapper import AwsWrapperConnection + if TYPE_CHECKING: from aws_advanced_python_wrapper.driver_dialect import DriverDialect from aws_advanced_python_wrapper.pep249 import Connection @@ -136,12 +138,11 @@ def test_sort_plugins_with_stick_to_prior(mocker): def test_unknown_profile(mocker, mock_telemetry_factory): - props = Properties(profile_name="unknown_profile") with pytest.raises(AwsWrapperError): - PluginManager(mocker.MagicMock(), props, mock_telemetry_factory()) + AwsWrapperConnection.connect(mocker.MagicMock(), "host=localhost wrapper_driver_dialect=mysql-connector-python profile_name=unknown_profile") -def test_execute_call_a(mocker, mock_conn, container, mock_driver_dialect, mock_telemetry_factory): +def test_execute_call_a(mocker, mock_conn, container, mock_driver_dialect, mock_telemetry_factory, mock_plugin_service): calls = [] args = [10, "arg2", 3.33] plugins = [TestPluginOne(calls), TestPluginTwo(calls), TestPluginThree(calls)] @@ -189,7 +190,7 @@ def _target_call(calls: List[str]): return "result_value" -def test_execute_call_b(mocker, container, mock_driver_dialect, mock_telemetry_factory): +def test_execute_call_b(mocker, container, mock_driver_dialect, mock_telemetry_factory, mock_conn): calls = [] args = [10, "arg2", 3.33] plugins = [TestPluginOne(calls), TestPluginTwo(calls), TestPluginThree(calls)] @@ -212,7 +213,7 @@ def test_execute_call_b(mocker, container, mock_driver_dialect, mock_telemetry_f assert calls[4] == "TestPluginOne:after execute" -def test_execute_call_c(mocker, container, mock_driver_dialect, mock_telemetry_factory): +def test_execute_call_c(mocker, container, mock_driver_dialect, mock_telemetry_factory, mock_conn): calls = [] args = [10, "arg2", 3.33] plugins = [TestPluginOne(calls), TestPluginTwo(calls), TestPluginThree(calls)] @@ -233,7 +234,7 @@ def test_execute_call_c(mocker, container, mock_driver_dialect, mock_telemetry_f assert calls[2] == "TestPluginOne:after execute" -def test_execute_against_old_target(mocker, container, mock_driver_dialect, mock_telemetry_factory): +def test_execute_against_old_target(mocker, container, mock_driver_dialect, mock_telemetry_factory, mock_conn): mocker.patch.object(PluginManager, "__init__", lambda w, x, y, z: None) manager = PluginManager(mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()) manager._container = container diff --git a/tests/unit/test_writer_failover_handler.py b/tests/unit/test_writer_failover_handler.py index ec28eb70..0844a09d 100644 --- a/tests/unit/test_writer_failover_handler.py +++ b/tests/unit/test_writer_failover_handler.py @@ -97,12 +97,12 @@ def reader_b(): @pytest.fixture def topology(writer, reader_a, reader_b): - return [writer, reader_a, reader_b] + return tuple((writer, reader_a, reader_b)) @pytest.fixture def new_topology(new_writer_host, reader_a, reader_b): - return [new_writer_host, reader_a, reader_b] + return tuple((new_writer_host, reader_a, reader_b)) @pytest.fixture(autouse=True) @@ -141,8 +141,7 @@ def force_connect_side_effect(host_info, _, __) -> Connection: assert not result.is_new_host assert result.new_connection is writer_connection_mock - expected = [call(writer.as_aliases(), HostAvailability.UNAVAILABLE), - call(writer.as_aliases(), HostAvailability.AVAILABLE)] + expected = [call(writer.as_aliases(), HostAvailability.AVAILABLE)] plugin_service_mock.set_availability.assert_has_calls(expected) @@ -152,8 +151,7 @@ def test_reconnect_to_writer_slow_task_b( reader_a_connection_mock, default_properties, new_writer_host, writer, reader_a, reader_b, topology, new_topology): exception = Exception("Test Exception") - expected = [call(writer.as_aliases(), HostAvailability.UNAVAILABLE), - call(writer.as_aliases(), HostAvailability.AVAILABLE)] + expected = [call(writer.as_aliases(), HostAvailability.AVAILABLE)] mock_hosts_property = mocker.PropertyMock(side_effect=chain([topology], cycle([new_topology]))) type(plugin_service_mock).hosts = mock_hosts_property @@ -224,8 +222,7 @@ def get_reader_connection_side_effect(_): assert not result.is_new_host assert result.new_connection is writer_connection_mock - expected = [call(writer.as_aliases(), HostAvailability.UNAVAILABLE), - call(writer.as_aliases(), HostAvailability.AVAILABLE)] + expected = [call(writer.as_aliases(), HostAvailability.AVAILABLE)] plugin_service_mock.set_availability.assert_has_calls(expected) @@ -270,8 +267,7 @@ def get_reader_connection_side_effect(_): assert result.is_new_host assert result.new_connection is new_writer_connection_mock - expected = [call(writer.as_aliases(), HostAvailability.UNAVAILABLE), - call(new_writer_host.as_aliases(), HostAvailability.AVAILABLE)] + expected = [call(new_writer_host.as_aliases(), HostAvailability.AVAILABLE)] plugin_service_mock.set_availability.assert_has_calls(expected) @@ -320,8 +316,7 @@ def get_reader_connection_side_effect(_): assert len(result.topology) == 4 assert "new-writer-host" == result.topology[0].host - expected = [call(writer.as_aliases(), HostAvailability.UNAVAILABLE), - call(new_writer_host.as_aliases(), HostAvailability.AVAILABLE)] + expected = [call(new_writer_host.as_aliases(), HostAvailability.AVAILABLE)] plugin_service_mock.set_availability.assert_has_calls(expected, any_order=True) plugin_service_mock.force_refresh_host_list.assert_called() @@ -376,7 +371,7 @@ def get_reader_connection_side_effect(_): assert not result.is_connected assert not result.is_new_host - expected = [call(writer.as_aliases(), HostAvailability.UNAVAILABLE)] + expected = [call(writer.as_aliases(), HostAvailability.AVAILABLE)] plugin_service_mock.set_availability.assert_has_calls(expected) plugin_service_mock.force_refresh_host_list.assert_called() @@ -419,7 +414,6 @@ def get_reader_connection_side_effect(_): assert not result.is_connected assert not result.is_new_host - expected = [call(writer.as_aliases(), HostAvailability.UNAVAILABLE), - call(new_writer_host.as_aliases(), HostAvailability.UNAVAILABLE)] + expected = [call(new_writer_host.as_aliases(), HostAvailability.UNAVAILABLE)] plugin_service_mock.set_availability.assert_has_calls(expected, any_order=True)