diff --git a/.github/workflows/failover.yml b/.github/workflows/failover.yml index de8c1e3df..05a7cdd13 100644 --- a/.github/workflows/failover.yml +++ b/.github/workflows/failover.yml @@ -1,6 +1,7 @@ name: Failover Unit Tests on: + workflow_dispatch: push: branches: - main diff --git a/docs/images/sample_custom_endpoints_dsn.png b/docs/images/sample_custom_endpoints_dsn.png new file mode 100644 index 000000000..b4baa558e Binary files /dev/null and b/docs/images/sample_custom_endpoints_dsn.png differ diff --git a/docs/using-the-aws-driver/CustomEndpoint.md b/docs/using-the-aws-driver/CustomEndpoint.md new file mode 100644 index 000000000..2dee5604f --- /dev/null +++ b/docs/using-the-aws-driver/CustomEndpoint.md @@ -0,0 +1,26 @@ +# Custom Endpoint Support + +The Custom Endpoint support allows client application to use the driver with [RDS custom endpoints](https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/Aurora.Endpoints.Custom.html). When the Custom Endpoint feature is enabled, the driver will analyse custom endpoint information to ensure instances used in connections are part of the custom endpoint being used. This includes connections used in failover. + +## How to use the Driver with Custom Endpoint + +### Enabling the Custom Endpoint Feature + +1. If needed, create a custom endpoint using the AWS RDS Console: + - If needed, review the documentation about [creating a custom endpoint](https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/aurora-custom-endpoint-creating.html). +2. Set `ENABLE_CUSTOM_ENDPOINT_MONITORING` to `TRUE` to enable custom endpoint support. +3. If you are using the failover plugin, set the failover parameter `FAILOVER_MODE` according to the custom endpoint type. For example, if the custom endpoint you are using is of type `READER`, you can set `FAILOVER_MODE` to `strict reader`, or if it is of type `ANY`, you can set `FAILOVER_MODE` to `reader or writer`. +4. Specify parameters that are required or specific to your case. + +### Custom Endpoint Plugin Parameters + +| Parameter | Value | Required | Description | Default Value | Example Value | +| ------------------------------------------ | :----: | :------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------c--------------------------------------------------------------------------------------------- | --------------------- | ------------- | +| `ENABLE_CUSTOM_ENDPOINT_MONITORING` | bool | No | Set to TRUE to enable custom endpoint support. | `FALSE` | `TRUE` | +| `CUSTOM_ENDPOINT_REGION` | string | No | The region of the cluster's custom endpoints. If not specified, the region will be parsed from the URL. | `N/A` | `us-west-1` | +| `CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS` | long | No | Controls how frequently custom endpoint monitors fetch custom endpoint info, in milliseconds. | `30000` | `20000` | +| `CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS` | long | No | Controls how long a monitor should run without use before expiring and being removed, in milliseconds. | `900000` (15 minutes) | `600000` | +| `WAIT_FOR_CUSTOM_ENDPOINT_INFO` | bool | No | Controls whether to wait for custom endpoint info to become available before connecting or executing a method. Waiting is only necessary if a connection to a given custom endpoint has not been opened or used recently. Note that disabling this may result in occasional connections to instances outside of the custom endpoint. | `TRUE` | `TRUE` | +| `WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS` | long | No | Controls the maximum amount of time that the plugin will wait for custom endpoint info to be made available by the custom endpoint monitor, in milliseconds. | `5000` | `7000` | + +![sample_custom_endpoints_dsn](../images/sample_custom_endpoints_dsn.png) diff --git a/driver/CMakeLists.txt b/driver/CMakeLists.txt index 1c9a090f7..e0a5252cc 100644 --- a/driver/CMakeLists.txt +++ b/driver/CMakeLists.txt @@ -62,6 +62,7 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT}) auth_util.cc aws_sdk_helper.cc base_metrics_holder.cc + cache_map.cc catalog.cc catalog_no_i_s.cc cluster_topology_info.cc @@ -72,6 +73,9 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT}) connect.cc connection_handler.cc connection_proxy.cc + custom_endpoint_info.cc + custom_endpoint_monitor.cc + custom_endpoint_proxy.cc cursor.cc desc.cc dll.cc @@ -131,9 +135,11 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT}) CONFIGURE_FILE(${CMAKE_SOURCE_DIR}/driver/driver.rc.cmake ${CMAKE_SOURCE_DIR}/driver/driver${CONNECTOR_DRIVER_TYPE_SHORT}.rc @ONLY) SET(DRIVER_SRCS ${DRIVER_SRCS} driver${CONNECTOR_DRIVER_TYPE_SHORT}.def driver${CONNECTOR_DRIVER_TYPE_SHORT}.rc adfs_proxy.h + allowed_and_blocked_hosts.h auth_util.h aws_sdk_helper.h base_metrics_holder.h + cache_map.h catalog.h cluster_aware_hit_metrics_holder.h cluster_aware_metrics_container.h @@ -142,6 +148,9 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT}) cluster_topology_info.h connection_handler.h connection_proxy.h + custom_endpoint_info.h + custom_endpoint_monitor.h + custom_endpoint_proxy.h driver.h efm_proxy.h error.h diff --git a/driver/allowed_and_blocked_hosts.h b/driver/allowed_and_blocked_hosts.h new file mode 100644 index 000000000..f32036478 --- /dev/null +++ b/driver/allowed_and_blocked_hosts.h @@ -0,0 +1,74 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#ifndef __ALLOWED_AND_BLOCKED_HOSTS__ +#define __ALLOWED_AND_BLOCKED_HOSTS__ + +#include +#include + +/** + * Represents the allowed and blocked hosts for connections. + */ +class ALLOWED_AND_BLOCKED_HOSTS { + public: + /** + * Constructs an AllowedAndBlockedHosts instance with the specified allowed and blocked host IDs. + * @param allowed_host_ids The set of allowed host IDs for connections. If null or empty, all host IDs that are not in + * `blocked_host_ids` are allowed. + * @param blocked_host_ids The set of blocked host IDs for connections. If null or empty, all host IDs in + * `allowed_host_ids` are allowed. If `allowed_host_ids` is also null or empty, there + * are no restrictions on which hosts are allowed. + */ + ALLOWED_AND_BLOCKED_HOSTS(const std::set& allowed_host_ids, + const std::set& blocked_host_ids) + : allowed_host_ids(allowed_host_ids), blocked_host_ids(blocked_host_ids){}; + + /** + * Returns the set of allowed host IDs for connections. If null or empty, all host IDs that are not in + * `blocked_host_ids` are allowed. + * + * @return the set of allowed host IDs for connections. + */ + std::set get_allowed_host_ids() { return this->allowed_host_ids; }; + + /** + * Returns the set of blocked host IDs for connections. If null or empty, all host IDs in `allowed_host_ids` + * are allowed. If `allowed_host_ids` is also null or empty, there are no restrictions on which hosts are allowed. + * + * @return the set of blocked host IDs for connections. + */ + std::set get_blocked_host_ids() { return this->blocked_host_ids; }; + + private: + std::set allowed_host_ids; + std::set blocked_host_ids; +}; + +#endif diff --git a/driver/auth_util.cc b/driver/auth_util.cc index 7ff6aac47..0b5fe1f57 100644 --- a/driver/auth_util.cc +++ b/driver/auth_util.cc @@ -74,7 +74,7 @@ std::pair AUTH_UTIL::get_auth_token(std::unordered_mapbuild_cache_key(host, region, port, user); + const std::string cache_key = build_cache_key(host, region, port, user); bool using_cached_token = false; { diff --git a/driver/cache_map.cc b/driver/cache_map.cc new file mode 100644 index 000000000..faa47fe05 --- /dev/null +++ b/driver/cache_map.cc @@ -0,0 +1,95 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#include "cache_map.h" + +#include + +#include "custom_endpoint_info.h" + +template +void CACHE_MAP::put(K key, V value, long long item_expiration_nanos) { + this->cache[key] = std::make_shared( + value, std::chrono::steady_clock::now() + std::chrono::nanoseconds(item_expiration_nanos)); + this->clean_up(); +} + +template +V CACHE_MAP::get(K key, V default_value) { + if (cache.count(key) > 0 && !cache[key]->is_expired()) { + return this->cache[key]->item; + } + return default_value; +} + +template +V CACHE_MAP::get(K key, V default_value, long long item_expiration_nanos) { + if (cache.count(key) == 0 || this->cache[key]->is_expired()) { + this->put(key, std::move(default_value), item_expiration_nanos); + } + return this->cache[key]->item; +} + +template +void CACHE_MAP::remove(K key) { + if (this->cache.count(key)) { + this->cache.erase(key); + } + this->clean_up(); +} + +template +int CACHE_MAP::size() { + return this->cache.size(); +} + +template +void CACHE_MAP::clear() { + this->cache.clear(); +} + +template +void CACHE_MAP::clean_up() { + if (std::chrono::steady_clock::now() > this->clean_up_time_nanos.load()) { + this->clean_up_time_nanos = + std::chrono::steady_clock::now() + std::chrono::nanoseconds(this->clean_up_time_interval_nanos); + std::vector keys; + keys.reserve(this->cache.size()); + for (auto& [key, cache_item] : this->cache) { + keys.push_back(key); + } + for (const auto& key : keys) { + if (this->cache[key]->is_expired()) { + this->cache.erase(key); + } + } + } +} + +template class CACHE_MAP>; diff --git a/driver/cache_map.h b/driver/cache_map.h new file mode 100644 index 000000000..f01adaecf --- /dev/null +++ b/driver/cache_map.h @@ -0,0 +1,74 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#ifndef __CACHE_MAP_H__ +#define __CACHE_MAP_H__ + +#include +#include +#include +#include + +template +class CACHE_MAP { + public: + class CACHE_ITEM { + public: + CACHE_ITEM() = default; + CACHE_ITEM(V item, std::chrono::steady_clock::time_point expiration_time) + : item(item), expiration_time(expiration_time){}; + ~CACHE_ITEM() = default; + V item; + + bool is_expired() { return std::chrono::steady_clock::now() > this->expiration_time; } + + private: + std::chrono::steady_clock::time_point expiration_time; + }; + + CACHE_MAP() = default; + ~CACHE_MAP() = default; + + void put(K key, V value, long long item_expiration_nanos); + V get(K key, V default_value); + V get(K key, V default_value, long long item_expiration_nanos); + void remove(K key); + int size(); + void clear(); + + protected: + void clean_up(); + const long long clean_up_time_interval_nanos = 60000000000; // 10 minute + std::atomic clean_up_time_nanos; + + private: + std::unordered_map> cache; +}; + +#endif diff --git a/driver/cluster_topology_info.cc b/driver/cluster_topology_info.cc index d33360d8a..25be5e1c0 100644 --- a/driver/cluster_topology_info.cc +++ b/driver/cluster_topology_info.cc @@ -30,6 +30,7 @@ #include "cluster_topology_info.h" #include +#include /** Initialize and return random number. @@ -75,6 +76,19 @@ void CLUSTER_TOPOLOGY_INFO::add_host(std::shared_ptr host_info) { update_time(); } +void CLUSTER_TOPOLOGY_INFO::remove_host(std::shared_ptr host_info) { + auto position = std::find(writers.begin(), writers.end(), host_info); + if (position != writers.end()) { + writers.erase(position); + } + + position = std::find(readers.begin(), readers.end(), host_info); + if (position != readers.end()) { + readers.erase(position); + } + update_time(); +} + size_t CLUSTER_TOPOLOGY_INFO::total_hosts() { return writers.size() + readers.size(); } @@ -136,6 +150,13 @@ std::vector> CLUSTER_TOPOLOGY_INFO::get_writers() { return writers; } +std::vector> CLUSTER_TOPOLOGY_INFO::get_instances() { + std::vector instances(writers); + instances.insert(instances.end(), readers.begin(), readers.end()); + + return instances; +} + std::shared_ptr CLUSTER_TOPOLOGY_INFO::get_last_used_reader() { return last_used_reader; } diff --git a/driver/cluster_topology_info.h b/driver/cluster_topology_info.h index 90d840370..1e7271ee7 100644 --- a/driver/cluster_topology_info.h +++ b/driver/cluster_topology_info.h @@ -46,6 +46,7 @@ class CLUSTER_TOPOLOGY_INFO { virtual ~CLUSTER_TOPOLOGY_INFO(); void add_host(std::shared_ptr host_info); + void remove_host(std::shared_ptr host_info); size_t total_hosts(); size_t num_readers(); // return number of readers in the cluster std::time_t time_last_updated(); @@ -58,6 +59,7 @@ class CLUSTER_TOPOLOGY_INFO { std::shared_ptr get_reader(int i); std::vector> get_writers(); std::vector> get_readers(); + std::vector> get_instances(); private: int current_reader = -1; diff --git a/driver/connection_handler.cc b/driver/connection_handler.cc index 5e1e388da..8a68547f1 100644 --- a/driver/connection_handler.cc +++ b/driver/connection_handler.cc @@ -83,6 +83,9 @@ CONNECTION_PROXY* CONNECTION_HANDLER::connect(std::shared_ptr host_in } my_SQLFreeConnect(dbc_clone); + if (new_connection != nullptr) { + new_connection->set_dbc(dbc); + } return new_connection; } diff --git a/driver/connection_proxy.h b/driver/connection_proxy.h index d4f44674b..b28d72f4d 100644 --- a/driver/connection_proxy.h +++ b/driver/connection_proxy.h @@ -177,6 +177,8 @@ class CONNECTION_PROXY { void set_custom_error_message(const char* error_message); + void set_dbc(DBC* dbc) { this->dbc = dbc; }; + protected: DBC* dbc = nullptr; DataSource* ds = nullptr; diff --git a/driver/custom_endpoint_info.cc b/driver/custom_endpoint_info.cc new file mode 100644 index 000000000..bb8bc0c9b --- /dev/null +++ b/driver/custom_endpoint_info.cc @@ -0,0 +1,84 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software {} you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY {} without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#include "custom_endpoint_info.h" + +std::shared_ptr CUSTOM_ENDPOINT_INFO::from_db_cluster_endpoint( + const Aws::RDS::Model::DBClusterEndpoint& response_endpoint_info) { + std::vector members; + MEMBERS_LIST_TYPE members_list_type; + + if (response_endpoint_info.StaticMembersHasBeenSet()) { + members = response_endpoint_info.GetStaticMembers(); + members_list_type = STATIC_LIST; + } else { + members = response_endpoint_info.GetExcludedMembers(); + members_list_type = EXCLUSION_LIST; + } + + std::set members_set(members.begin(), members.end()); + + return std::make_shared( + response_endpoint_info.GetDBClusterEndpointIdentifier(), response_endpoint_info.GetDBClusterIdentifier(), + response_endpoint_info.GetEndpoint(), + CUSTOM_ENDPOINT_INFO::get_role_type(response_endpoint_info.GetCustomEndpointType()), members_set, + members_list_type); +} + +std::set CUSTOM_ENDPOINT_INFO::get_excluded_members() const { + if (this->member_list_type == EXCLUSION_LIST) { + return members; + } + + return std::set(); +} + +std::set CUSTOM_ENDPOINT_INFO::get_static_members() const { + if (this->member_list_type == STATIC_LIST) { + return members; + } + + return std::set(); +} + +bool operator==(const CUSTOM_ENDPOINT_INFO& current, const CUSTOM_ENDPOINT_INFO& other) { + return current.endpoint_identifier == other.endpoint_identifier && + current.cluster_identifier == other.cluster_identifier && current.url == other.url && + current.role_type == other.role_type && + current.member_list_type == other.member_list_type; +} + +CUSTOM_ENDPOINT_ROLE_TYPE CUSTOM_ENDPOINT_INFO::get_role_type(const Aws::String& role_type) { + auto it = CUSTOM_ENDPOINT_ROLE_TYPE_MAP.find(role_type); + if (it != CUSTOM_ENDPOINT_ROLE_TYPE_MAP.end()) { + return it->second; + } + + throw std::invalid_argument("Invalid role type for custom endpoint, this should not have happened."); +} diff --git a/driver/custom_endpoint_info.h b/driver/custom_endpoint_info.h new file mode 100644 index 000000000..48f68b394 --- /dev/null +++ b/driver/custom_endpoint_info.h @@ -0,0 +1,134 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#ifndef __CUSTOM_ENDPOINT_INFO_H__ +#define __CUSTOM_ENDPOINT_INFO_H__ + +#include + +#include +#include "stringutil.h" +#include "mylog.h" + +/** + * Enum representing the possible roles of instances specified by a custom endpoint. Note that, currently, it is not + * possible to create a WRITER custom endpoint. + */ +enum CUSTOM_ENDPOINT_ROLE_TYPE { + ANY, // Instances in the custom endpoint may be either a writer or a reader. + WRITER, // Instance in the custom endpoint is always the writer. + READER // Instances in the custom endpoint are always readers. +}; + +static std::unordered_map const CUSTOM_ENDPOINT_ROLE_TYPE_MAP = { + {"ANY", ANY}, {"WRITER", WRITER}, {"READER", READER}}; + +static std::unordered_map const CUSTOM_ENDPOINT_ROLE_TYPE_STR_MAP = { + {ANY, "ANY"}, {WRITER, "WRITER"}, {READER, "READER"}}; + +/** + * Enum representing the member list type of a custom endpoint. This information can be used together with a member list + * to determine which instances are included or excluded from a custom endpoint. + */ +enum MEMBERS_LIST_TYPE { + /** + * The member list for the custom endpoint specifies which instances are included in the custom endpoint. If new + * instances are added to the cluster, they will not be automatically added to the custom endpoint. + */ + STATIC_LIST, + /** + * The member list for the custom endpoint specifies which instances are excluded from the custom endpoint. If new + * instances are added to the cluster, they will be automatically added to the custom endpoint. + */ + EXCLUSION_LIST +}; + +static std::unordered_map const MEMBERS_LIST_TYPE_MAP = { + {STATIC_LIST, "STATIC_LIST"}, {EXCLUSION_LIST, "EXCLUSION_LIST"}}; + +class CUSTOM_ENDPOINT_INFO { + public: + CUSTOM_ENDPOINT_INFO(std::string endpoint_identifier, std::string cluster_identifier, std::string url, + CUSTOM_ENDPOINT_ROLE_TYPE role_type, std::set members, + MEMBERS_LIST_TYPE member_list_type) + : endpoint_identifier(std::move(endpoint_identifier)), + cluster_identifier(std::move(cluster_identifier)), + url(std::move(url)), + role_type(role_type), + members(std::move(members)), + member_list_type(member_list_type){}; + ~CUSTOM_ENDPOINT_INFO() = default; + + static std::shared_ptr from_db_cluster_endpoint( + const Aws::RDS::Model::DBClusterEndpoint& response_endpoint_info); + std::string get_endpoint_identifier() const { return this->endpoint_identifier; }; + std::string get_cluster_identifier() const { return this->cluster_identifier; }; + std::string get_url() const { return this->url; }; + CUSTOM_ENDPOINT_ROLE_TYPE get_custom_endpoint_type() const { return this->role_type; }; + MEMBERS_LIST_TYPE get_member_list_type() const { return this->member_list_type; }; + std::set get_excluded_members() const; + std::set get_static_members() const; + + std::string to_string() const { + char buf[4096]; + std::string members_list; + + for (auto const& m : members) { + members_list += m; + members_list += ","; + } + if (members_list.empty()) { + members_list = ""; + } else { + members_list.pop_back(); + } + + myodbc_snprintf( + buf, sizeof(buf), + "CustomEndpointInfo[url=%s, cluster_identifier=%s, custom_endpoint_type=%s, member_list_type=%s, members=[%s]", + this->url.c_str(), this->cluster_identifier.c_str(), + CUSTOM_ENDPOINT_ROLE_TYPE_STR_MAP.at(this->role_type).c_str(), + MEMBERS_LIST_TYPE_MAP.at(this->member_list_type).c_str(), members_list.c_str()); + + return std::string(buf); + } + + friend bool operator==(const CUSTOM_ENDPOINT_INFO& current, const CUSTOM_ENDPOINT_INFO& other); + + private: + const std::string endpoint_identifier; + const std::string cluster_identifier; + const std::string url; + const CUSTOM_ENDPOINT_ROLE_TYPE role_type; + const std::set members; + const MEMBERS_LIST_TYPE member_list_type; + static CUSTOM_ENDPOINT_ROLE_TYPE get_role_type(const Aws::String& role_type); +}; + +#endif diff --git a/driver/custom_endpoint_monitor.cc b/driver/custom_endpoint_monitor.cc new file mode 100644 index 000000000..f715aa27d --- /dev/null +++ b/driver/custom_endpoint_monitor.cc @@ -0,0 +1,207 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#include +#include +#include +#include +#include +#include + +#include "allowed_and_blocked_hosts.h" +#include "aws_sdk_helper.h" +#include "custom_endpoint_monitor.h" +#include "driver.h" +#include "mylog.h" + +namespace { +AWS_SDK_HELPER SDK_HELPER; +} + +CACHE_MAP> CUSTOM_ENDPOINT_MONITOR::custom_endpoint_cache; + +CUSTOM_ENDPOINT_MONITOR::CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr topology_service, + const std::string& custom_endpoint_host, + const std::string& endpoint_identifier, const std::string& region, + long long refresh_rate_nanos, ctpl::thread_pool& thread_pool, + bool enable_logging) + : topology_service(topology_service), + custom_endpoint_host(custom_endpoint_host), + endpoint_identifier(endpoint_identifier), + region(region), + refresh_rate_nanos(refresh_rate_nanos), + thread_pool(thread_pool), + enable_logging(enable_logging) { + if (enable_logging) { + this->logger = init_log_file(); + } + + this->run(); +} + +#ifdef UNIT_TEST_BUILD +CUSTOM_ENDPOINT_MONITOR::CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr topology_service, + const std::string& custom_endpoint_host, + const std::string& endpoint_identifier, const std::string& region, + long long refresh_rate_nanos, ctpl::thread_pool& thread_pool, + bool enable_logging, std::shared_ptr client) + : topology_service(topology_service), + custom_endpoint_host(custom_endpoint_host), + endpoint_identifier(endpoint_identifier), + region(region), + refresh_rate_nanos(refresh_rate_nanos), + thread_pool(thread_pool), + enable_logging(enable_logging) { + if (enable_logging) { + this->logger = init_log_file(); + } + this->run(); +} +#endif + +bool CUSTOM_ENDPOINT_MONITOR::should_dispose() { return true; } + +bool CUSTOM_ENDPOINT_MONITOR::has_custom_endpoint_info() const { + auto default_val = std::shared_ptr(nullptr); + return custom_endpoint_cache.get(this->custom_endpoint_host, default_val) != default_val; +} + +void CUSTOM_ENDPOINT_MONITOR::run() { + if (thread_pool.size() == 1) { + // Each monitor should only have 1 thread. + return; + } + MYLOG_TRACE(this->logger, 0, "Starting custom endpoint monitor for '%s'", this->custom_endpoint_host.c_str()); + thread_pool.resize(1); + thread_pool.push([=](int id) { + ++SDK_HELPER; + + Aws::RDS::RDSClientConfiguration client_config; + if (!region.empty()) { + client_config.region = region; + } + + const Aws::RDS::RDSClient rds_client(Aws::Auth::DefaultAWSCredentialsProviderChain().GetAWSCredentials(), + client_config); + Aws::RDS::Model::Filter filter; + filter.SetName("db-cluster-endpoint-type"); + filter.AddValues("custom"); + + Aws::RDS::Model::DescribeDBClusterEndpointsRequest request; + request.SetDBClusterEndpointIdentifier(this->endpoint_identifier); + // TODO: Investigate why filters returns `InvalidParameterCombination` error saying filter values are null. + // request.AddFilters(filter); + try { + while (!should_stop) { + const std::chrono::time_point start = std::chrono::steady_clock::now(); + const auto response = rds_client.DescribeDBClusterEndpoints(request); + const auto custom_endpoints = response.GetResult().GetDBClusterEndpoints(); + if (custom_endpoints.size() != 1) { + MYLOG_TRACE(this->logger, 0, + "Unexpected number of custom endpoints with endpoint identifier '%s' in region '%s'. Expected 1 " + "custom endpoint, but found %d. Endpoints: %s", + endpoint_identifier.c_str(), region.c_str(), custom_endpoints.size(), + this->get_endpoints_as_string(custom_endpoints).c_str()); + + std::this_thread::sleep_for(std::chrono::nanoseconds(this->refresh_rate_nanos)); + continue; + } + const std::shared_ptr endpoint_info = + CUSTOM_ENDPOINT_INFO::from_db_cluster_endpoint(custom_endpoints[0]); + const std::shared_ptr cache_endpoint_info = + custom_endpoint_cache.get(this->custom_endpoint_host, nullptr); + + if (cache_endpoint_info != nullptr && cache_endpoint_info == endpoint_info) { + const long long elapsed_time = + std::chrono::duration_cast(std::chrono::steady_clock::now() - start).count(); + std::this_thread::sleep_for( + std::chrono::nanoseconds(std::max(static_cast(0), this->refresh_rate_nanos - elapsed_time))); + continue; + } + + MYLOG_TRACE(this->logger, 0, "Detected change in custom endpoint info for '%s':\n{%s}", + custom_endpoint_host.c_str(), endpoint_info->to_string().c_str()); + + // The custom endpoint info has changed, so we need to update the set of allowed/blocked hosts. + std::shared_ptr allowed_and_blocked_hosts; + if (endpoint_info->get_member_list_type() == STATIC_LIST) { + allowed_and_blocked_hosts = + std::make_shared(endpoint_info->get_static_members(), std::set()); + } else { + allowed_and_blocked_hosts = std::make_shared( + std::set(), endpoint_info->get_excluded_members()); + } + + this->topology_service->set_allowed_and_blocked_hosts(allowed_and_blocked_hosts); + custom_endpoint_cache.put(this->custom_endpoint_host, endpoint_info, CUSTOM_ENDPOINT_INFO_EXPIRATION_NANOS); + const long long elapsed_time = + std::chrono::duration_cast(std::chrono::steady_clock::now() - start).count(); + std::this_thread::sleep_for( + std::chrono::nanoseconds(std::max(static_cast(0), this->refresh_rate_nanos - elapsed_time))); + } + + --SDK_HELPER; + } catch (const std::exception& e) { + // Log and continue monitoring. + --SDK_HELPER; + MYLOG_TRACE(this->logger, 0, "Error while monitoring custom endpoint: %s", e.what()); + } + + should_stop = true; + }); +} + +std::string CUSTOM_ENDPOINT_MONITOR::get_endpoints_as_string( + const std::vector& custom_endpoints) { + if (custom_endpoints.empty()) { + return ""; + } + + std::string endpoints("["); + + for (auto const& e : custom_endpoints) { + endpoints += e.GetDBClusterEndpointIdentifier(); + endpoints += ","; + } + + endpoints.pop_back(); + endpoints += "]"; + + return endpoints; +} + +void CUSTOM_ENDPOINT_MONITOR::stop() { + should_stop = true; + thread_pool.stop(true); + thread_pool.resize(0); + custom_endpoint_cache.remove(this->custom_endpoint_host); + MYLOG_TRACE(this->logger, 0, "Stopped custom endpoint monitor for '%s'", this->custom_endpoint_host.c_str()); +} + +void CUSTOM_ENDPOINT_MONITOR::clear_cache() { custom_endpoint_cache.clear(); } diff --git a/driver/custom_endpoint_monitor.h b/driver/custom_endpoint_monitor.h new file mode 100644 index 000000000..01843de4d --- /dev/null +++ b/driver/custom_endpoint_monitor.h @@ -0,0 +1,83 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#ifndef __CUSTOM_ENDPOINT_MONITOR_H__ +#define __CUSTOM_ENDPOINT_MONITOR_H__ + +#include + +#include +#include "cache_map.h" +#include "custom_endpoint_info.h" +#include "host_info.h" +#include "topology_service.h" + +class CUSTOM_ENDPOINT_MONITOR : public std::enable_shared_from_this { + public: + CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr topology_service, + const std::string& custom_endpoint_host, const std::string& endpoint_identifier, + const std::string& region, long long refresh_rate_nanos, ctpl::thread_pool& thread_pool, + bool enable_logging = false); +#ifdef UNIT_TEST_BUILD + CUSTOM_ENDPOINT_MONITOR(ctpl::thread_pool& pool): thread_pool(pool){}; + CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr topology_service, + const std::string& custom_endpoint_host, const std::string& endpoint_identifier, + const std::string& region, long long refresh_rate_nanos, ctpl::thread_pool& thread_pool, + bool enable_logging, std::shared_ptr client); +#endif + + static bool should_dispose(); + bool has_custom_endpoint_info() const; + void stop(); + void run(); + static void clear_cache(); + + protected: + static CACHE_MAP> custom_endpoint_cache; + static constexpr long long CUSTOM_ENDPOINT_INFO_EXPIRATION_NANOS = 300000000000; // 5 minutes + std::string custom_endpoint_host; + std::string endpoint_identifier; + std::string region; + long long refresh_rate_nanos; + bool enable_logging; + std::shared_ptr logger; + ctpl::thread_pool& thread_pool; + bool should_stop = false; + std::shared_ptr topology_service; + + private: + static std::string get_endpoints_as_string(const std::vector& custom_endpoints); + +#ifdef UNIT_TEST_BUILD + // Allows for testing private/protected methods + friend class TEST_UTILS; +#endif +}; + +#endif diff --git a/driver/custom_endpoint_proxy.cc b/driver/custom_endpoint_proxy.cc new file mode 100644 index 000000000..0a9b808f8 --- /dev/null +++ b/driver/custom_endpoint_proxy.cc @@ -0,0 +1,171 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#include "custom_endpoint_proxy.h" +#include "mylog.h" +#include "rds_utils.h" + +SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD> + CUSTOM_ENDPOINT_PROXY::monitors(std::make_shared(), + std::make_shared(), CACHE_CLEANUP_RATE_NANO); + +bool CUSTOM_ENDPOINT_PROXY::is_monitor_cache_initialized(false); + +CUSTOM_ENDPOINT_PROXY::CUSTOM_ENDPOINT_PROXY(DBC* dbc, DataSource* ds) : CUSTOM_ENDPOINT_PROXY(dbc, ds, nullptr) {} +CUSTOM_ENDPOINT_PROXY::CUSTOM_ENDPOINT_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy) + : CONNECTION_PROXY(dbc, ds) { + this->next_proxy = next_proxy; + this->topology_service = dbc->get_topology_service(); + + if (ds->opt_LOG_QUERY) { + this->logger = init_log_file(); + } + this->should_wait_for_info = ds->opt_WAIT_FOR_CUSTOM_ENDPOINT_INFO; + this->wait_on_cached_info_duration_ms = ds->opt_WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS; + this->idle_monitor_expiration_ms = ds->opt_CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS; + + if (!is_monitor_cache_initialized) { + monitors.init_clean_up_thread(); + is_monitor_cache_initialized = true; + } +} + +bool CUSTOM_ENDPOINT_PROXY::connect(const char* host, const char* user, const char* password, const char* database, + unsigned int port, const char* socket, unsigned long flags) { + if (!RDS_UTILS::is_rds_custom_cluster_dns(host)) { + return this->next_proxy->connect(host, user, password, database, port, socket, flags); + } + + this->custom_endpoint_host = host; + MYLOG_TRACE(this->logger, 0, "Detected a connection request to a custom endpoint URL: '%s'", host); + + this->custom_endpoint_id = RDS_UTILS::get_rds_cluster_id(host); + + if (this->custom_endpoint_id.empty()) { + this->set_custom_error_message("Unable to parse custom endpoint identifier from URL."); + return false; + } + + this->region = ds->opt_CUSTOM_ENDPOINT_REGION ? static_cast(ds->opt_CUSTOM_ENDPOINT_REGION) + : RDS_UTILS::get_rds_region(host); + if (this->region.empty()) { + this->set_custom_error_message( + "Unable to determine connection region. If you are using a non-standard RDS URL, please set the " + "'custom_endpoint_region' property"); + return false; + } + + const std::shared_ptr monitor = create_monitor_if_absent(ds); + if (this->should_wait_for_info) { + // If needed, wait a short time for custom endpoint info to be discovered. + this->wait_for_custom_endpoint_info(monitor); + } + return this->next_proxy->connect(host, user, password, database, port, socket, flags); +} + +int CUSTOM_ENDPOINT_PROXY::query(const char* q) { + if (!this->custom_endpoint_host.empty()) { + const std::shared_ptr monitor = create_monitor_if_absent(ds); + if (this->should_wait_for_info) { + // If needed, wait a short time for custom endpoint info to be discovered. + this->wait_for_custom_endpoint_info(monitor); + } + } + + return next_proxy->query(q); +} + +int CUSTOM_ENDPOINT_PROXY::real_query(const char* q, unsigned long length) { + if (!this->custom_endpoint_host.empty()) { + const std::shared_ptr monitor = create_monitor_if_absent(ds); + if (this->should_wait_for_info) { + // If needed, wait a short time for custom endpoint info to be discovered. + this->wait_for_custom_endpoint_info(monitor); + } + } + + return next_proxy->real_query(q, length); +} + +void CUSTOM_ENDPOINT_PROXY::wait_for_custom_endpoint_info(std::shared_ptr monitor) { + bool has_custom_endpoint_info = monitor->has_custom_endpoint_info(); + + if (has_custom_endpoint_info) { + return; + } + + // Wait for the monitor to place the custom endpoint info in the cache. This ensures other plugins get accurate + // custom endpoint info. + MYLOG_TRACE(this->logger, 0, + "Custom endpoint info for '%s' was not found. Waiting %dms for the endpoint monitor to fetch info...", + this->custom_endpoint_host.c_str(), this->wait_on_cached_info_duration_ms) + + const auto wait_for_endpoint_info_timeout_nanos = + std::chrono::steady_clock::now() + std::chrono::milliseconds(this->wait_on_cached_info_duration_ms); + + while (!has_custom_endpoint_info && std::chrono::steady_clock::now() < wait_for_endpoint_info_timeout_nanos) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + has_custom_endpoint_info = monitor->has_custom_endpoint_info(); + } + + if (!has_custom_endpoint_info) { + char buf[1024]; + myodbc_snprintf( + buf, sizeof(buf), + "The custom endpoint plugin timed out after %ld ms while waiting for custom endpoint info for host %s.", + this->wait_on_cached_info_duration_ms, this->custom_endpoint_host.c_str()); + + set_custom_error_message(buf); + } +} + +std::shared_ptr CUSTOM_ENDPOINT_PROXY::create_custom_endpoint_monitor( + const long long refresh_rate_nanos) { + return std::make_shared(this->topology_service, this->custom_endpoint_host, + this->custom_endpoint_id, this->region, refresh_rate_nanos, + this->dbc->env->custom_endpoint_thread_pool); +} + +std::shared_ptr CUSTOM_ENDPOINT_PROXY::create_monitor_if_absent(DataSource* ds) { + const long long refresh_rate_nanos = std::chrono::duration_cast( + std::chrono::milliseconds(ds->opt_CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS)) + .count(); + + return monitors.compute_if_absent( + this->custom_endpoint_host, + [=](std::string key) { return this->create_custom_endpoint_monitor(refresh_rate_nanos); }, + std::chrono::duration_cast(std::chrono::milliseconds(this->idle_monitor_expiration_ms)) + .count()); +} + +void CUSTOM_ENDPOINT_PROXY::release_resources() { + if (!monitors.empty()) { + monitors.release_resources(); + } +} diff --git a/driver/custom_endpoint_proxy.h b/driver/custom_endpoint_proxy.h new file mode 100644 index 000000000..ed94df552 --- /dev/null +++ b/driver/custom_endpoint_proxy.h @@ -0,0 +1,100 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a scopy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#ifndef __CUSTOM_ENDPOINT_PROXY_H__ +#define __CUSTOM_ENDPOINT_PROXY_H__ + +#include +#include "connection_proxy.h" +#include "custom_endpoint_monitor.h" +#include "driver.h" +#include "sliding_expiration_cache_with_clean_up_thread.h" + +class CUSTOM_ENDPOINT_PROXY : public CONNECTION_PROXY { + public: + CUSTOM_ENDPOINT_PROXY(DBC* dbc, DataSource* ds); + CUSTOM_ENDPOINT_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy); + + bool connect(const char* host, const char* user, const char* password, const char* database, unsigned int port, + const char* socket, unsigned long flags) override; + + int query(const char* q) override; + int real_query(const char* q, unsigned long length) override; + + static void release_resources(); + + class CUSTOM_ENDPOINTS_SHOULD_DISPOSE_FUNC : public SHOULD_DISPOSE_FUNC> { + public: + bool should_dispose(std::shared_ptr item) override { return true; } + }; + + class CUSTOM_ENDPOINTS_ITEM_DISPOSAL_FUNC : public ITEM_DISPOSAL_FUNC> { + public: + void dispose(const std::shared_ptr monitor) override { + try { + monitor->stop(); + } catch (const std::exception& e) { + // Ignore + } + } + }; + static constexpr long long CACHE_CLEANUP_RATE_NANO = 60000000000; + + protected: + static bool is_monitor_cache_initialized; + std::string custom_endpoint_id; + std::string region; + std::string custom_endpoint_host; + std::shared_ptr rds_client; + bool should_wait_for_info; + long wait_on_cached_info_duration_ms; + long idle_monitor_expiration_ms; + std::shared_ptr topology_service; + + static SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD> monitors; + + std::shared_ptr create_monitor_if_absent(DataSource* ds); + + /** + * If custom endpoint info does not exist for the current custom endpoint, waits a short time for the info to be + * made available by the custom endpoint monitor. + * Since custom endpoint monitors and information are shared, we should not have to wait often. + */ + void wait_for_custom_endpoint_info(std::shared_ptr monitor); + + private: + std::shared_ptr logger; + virtual std::shared_ptr create_custom_endpoint_monitor(long long refresh_rate_nanos); +#ifdef UNIT_TEST_BUILD + // Allows for testing private/protected methods + friend class TEST_UTILS; +#endif +}; + +#endif diff --git a/driver/driver.h b/driver/driver.h index c41caad62..12c87a1fa 100644 --- a/driver/driver.h +++ b/driver/driver.h @@ -47,6 +47,7 @@ #include "connection_handler.h" #include "connection_proxy.h" +#include "topology_service.h" #include "failover.h" /* Disable _attribute__ on non-gcc compilers. */ @@ -568,6 +569,7 @@ struct ENV MYERROR error; std::mutex lock; ctpl::thread_pool failover_thread_pool; + ctpl::thread_pool custom_endpoint_thread_pool; ENV(SQLINTEGER ver) : odbc_ver(ver) {} @@ -627,6 +629,7 @@ struct DBC FAILOVER_HANDLER *fh = nullptr; /* Failover handler */ std::shared_ptr connection_handler = nullptr; + std::shared_ptr topology_service = nullptr; DBC(ENV *p_env); void free_explicit_descriptors(); @@ -639,6 +642,10 @@ struct DBC void execute_prep_stmt(MYSQL_STMT *pstmt, std::string &query, std::vector ¶m_bind, MYSQL_BIND *result_bind); void init_proxy_chain(DataSource *dsrc); + std::shared_ptr get_topology_service() { + return this->topology_service ? this->topology_service + : std::make_shared(this->id, ds ? ds->opt_LOG_QUERY : false); + } inline bool transactions_supported() { return connection_proxy->get_server_capabilities() & CLIENT_TRANSACTIONS; diff --git a/driver/failover_handler.cc b/driver/failover_handler.cc index f3233ac4d..fbb705e75 100644 --- a/driver/failover_handler.cc +++ b/driver/failover_handler.cc @@ -52,8 +52,7 @@ const char* MYSQL_READONLY_QUERY = "SELECT @@innodb_read_only AS is_reader"; FAILOVER_HANDLER::FAILOVER_HANDLER(DBC* dbc, DataSource* ds) : FAILOVER_HANDLER( - dbc, ds, dbc ? dbc->connection_handler : nullptr, - std::make_shared(dbc ? dbc->id : 0, ds ? ds->opt_LOG_QUERY : false), + dbc, ds, dbc ? dbc->connection_handler : nullptr, dbc ? dbc->get_topology_service() : nullptr, std::make_shared(dbc, ds)) {} FAILOVER_HANDLER::FAILOVER_HANDLER(DBC* dbc, DataSource* ds, @@ -431,6 +430,7 @@ void FAILOVER_HANDLER::initialize_topology() { MYLOG_DBC_TRACE(dbc, "[FAILOVER_HANDLER] m_is_cluster_topology_available=%s", m_is_cluster_topology_available ? "true" : "false"); + MYLOG_DBC_TRACE(dbc, topology_service->log_topology(current_topology).c_str()); if (is_failover_enabled()) { this->dbc->env->failover_thread_pool.resize(current_topology->total_hosts()); } @@ -505,13 +505,13 @@ bool FAILOVER_HANDLER::trigger_failover_if_needed(const char* error_code, } bool FAILOVER_HANDLER::failover_to_reader(const char*& new_error_code, const char*& error_msg) { - MYLOG_DBC_TRACE(dbc, "[FAILOVER_HANDLER] Starting reader failover procedure."); - auto result = failover_reader_handler->failover(current_topology); + MYLOG_DBC_TRACE(dbc, "[FAILOVER_HANDLER] Starting reader failover procedure with filtered topology: %s", this->topology_service->log_topology(this->topology_service->get_filtered_topology(current_topology)).c_str()); + auto result = failover_reader_handler->failover(this->topology_service->get_filtered_topology(current_topology)); if (result->connected) { current_host = result->new_host; connection_handler->update_connection(result->new_connection, current_host->get_host()); - new_error_code = "08S02"; + new_error_code = "08S02"; // Failover succeeded error code. error_msg = "The active SQL connection has changed."; MYLOG_DBC_TRACE(dbc, "[FAILOVER_HANDLER] The active SQL connection has changed " @@ -520,7 +520,7 @@ bool FAILOVER_HANDLER::failover_to_reader(const char*& new_error_code, const cha return true; } else { MYLOG_DBC_TRACE(dbc, "[FAILOVER_HANDLER] Unable to establish SQL connection to reader node."); - new_error_code = "08S01"; + new_error_code = "08S01"; // Failover failed error code. error_msg = "The active SQL connection was lost."; return false; } @@ -537,16 +537,32 @@ bool FAILOVER_HANDLER::failover_to_writer(const char*& new_error_code, const cha error_msg = "The active SQL connection was lost."; return false; } + + const auto new_topology = result->new_topology; + const auto new_host = new_topology->get_writer(); + if (result->is_new_host) { // connected to a new writer host; take it over - current_topology = result->new_topology; - current_host = current_topology->get_writer(); + current_topology = new_topology; + current_host = new_host; + } + const auto filtered_topology = this->topology_service->get_filtered_topology(new_topology); + const auto allowed_hosts = filtered_topology->get_instances(); + if (std::find(allowed_hosts.begin(), allowed_hosts.end(), new_host) == allowed_hosts.end()) { + new_error_code = "08S01"; // Failover failed error code. + error_msg = "The active SQL connection was lost."; + MYLOG_DBC_TRACE( + dbc, + "[FAILOVER_HANDLER] The failover process identified the new writer but the host is not in the list of allowed hosts. " + "New writer host: '%s'. Allowed hosts: '%s'", + new_host->get_host().c_str(), + this->topology_service->log_topology(filtered_topology).c_str()); + return false; } - connection_handler->update_connection( - result->new_connection, result->new_topology->get_writer()->get_host()); + connection_handler->update_connection(result->new_connection, new_host->get_host()); - new_error_code = "08S02"; + new_error_code = "08S02"; // Failover succeeded error code. error_msg = "The active SQL connection has changed."; MYLOG_DBC_TRACE( dbc, diff --git a/driver/handle.cc b/driver/handle.cc index 5f6386fd1..17f380dc5 100644 --- a/driver/handle.cc +++ b/driver/handle.cc @@ -57,6 +57,8 @@ #include +#include "custom_endpoint_proxy.h" + thread_local long thread_count = 0; std::mutex g_lock; @@ -100,7 +102,6 @@ void DBC::remove_desc(DESC* desc) desc_list.remove(desc); } - void DBC::free_explicit_descriptors() { @@ -121,7 +122,9 @@ void DBC::close() // construct a proxy chain, example: iam->efm->mysql void DBC::init_proxy_chain(DataSource* dsrc) { - CONNECTION_PROXY *head = new MYSQL_PROXY(this, dsrc); + this->topology_service = std::make_shared(this->id, ds ? ds->opt_LOG_QUERY : false); + + CONNECTION_PROXY* head = new MYSQL_PROXY(this, dsrc); if (dsrc->opt_ENABLE_FAILURE_DETECTION) { CONNECTION_PROXY* efm_proxy = new EFM_PROXY(this, dsrc); @@ -157,6 +160,12 @@ void DBC::init_proxy_chain(DataSource* dsrc) } } + if (dsrc->opt_ENABLE_CUSTOM_ENDPOINT_MONITORING) { + CONNECTION_PROXY* custom_endpoint_proxy = new CUSTOM_ENDPOINT_PROXY(this, dsrc); + custom_endpoint_proxy->set_next_proxy(head); + head = custom_endpoint_proxy; + } + this->connection_proxy = head; } @@ -165,7 +174,8 @@ DBC::~DBC() if (env) env->remove_dbc(this); - if (connection_proxy) + this->topology_service.reset(); + if (connection_proxy) delete connection_proxy; if (fh) @@ -245,6 +255,7 @@ SQLRETURN SQL_API SQLAllocEnv(SQLHENV *phenv) SQLRETURN SQL_API my_SQLFreeEnv(SQLHENV henv) { MONITOR_THREAD_CONTAINER::release_instance(); + CUSTOM_ENDPOINT_PROXY::release_resources(); ENV *env= (ENV *) henv; delete env; diff --git a/driver/host_info.cc b/driver/host_info.cc index ba9e3beba..cf8d1d138 100644 --- a/driver/host_info.cc +++ b/driver/host_info.cc @@ -29,6 +29,8 @@ #include "host_info.h" +#include "rds_utils.h" + // TODO // the entire HOST_INFO needs to be reviewed based on needed interfaces and other objects like CLUSTER_TOPOLOGY_INFO // most/all of the HOST_INFO potentially could be internal to CLUSTER_TOPOLOGY_INFO and specfic information may be accessed @@ -45,30 +47,27 @@ HOST_INFO::HOST_INFO(const char* host, int port) : HOST_INFO(host, port, UP, false) {} HOST_INFO::HOST_INFO(std::string host, int port, HOST_STATE state, bool is_writer) - : host{ host }, port{ port }, host_state{ state }, is_writer{ is_writer } -{ -} + : host{host}, host_id{RDS_UTILS::get_rds_instance_id(host)}, port{port}, host_state{state}, is_writer{is_writer} {} // would need some checks for nulls HOST_INFO::HOST_INFO(const char* host, int port, HOST_STATE state, bool is_writer) - : host{ host }, port{ port }, host_state{ state }, is_writer{ is_writer } -{ -} + : host{host}, host_id{RDS_UTILS::get_rds_instance_id(host)}, port{port}, host_state{state}, is_writer{is_writer} {} HOST_INFO::~HOST_INFO() {} /** - * Returns the host. - * - * @return the host + * @return the host endpoint */ std::string HOST_INFO::get_host() { return host; } /** - * Returns the port. - * + * @return the host name + */ +std::string HOST_INFO::get_host_id() { return host_id; } + +/** * @return the port */ int HOST_INFO::get_port() { @@ -76,8 +75,6 @@ int HOST_INFO::get_port() { } /** - * Returns a host:port representation of this host. - * * @return the host:port representation of this host */ std::string HOST_INFO::get_host_port_pair() { diff --git a/driver/host_info.h b/driver/host_info.h index e5c64a420..7fd64c1ad 100644 --- a/driver/host_info.h +++ b/driver/host_info.h @@ -32,6 +32,7 @@ #include #include +#include enum HOST_STATE { UP, DOWN }; @@ -49,6 +50,7 @@ class HOST_INFO { int get_port(); std::string get_host(); + std::string get_host_id(); std::string get_host_port_pair(); bool equal_host_port_pair(HOST_INFO& hi); HOST_STATE get_host_state(); @@ -69,10 +71,18 @@ class HOST_INFO { private: const std::string HOST_PORT_SEPARATOR = ":"; const std::string host; + const std::string host_id; const int port = NO_PORT; HOST_STATE host_state; bool is_writer; }; +inline std::ostream& operator<<(std::ostream& str, HOST_INFO v) { + char buf[1024]; + sprintf(buf, "HostSpec[host=%s, port=%d, %s, %s]", v.get_host().c_str(), v.get_port(), + v.is_host_writer() ? "WRITER" : "READER", v.last_updated.c_str()); + return str << std::string(buf); +} + #endif /* __HOSTINFO_H__ */ diff --git a/driver/rds_utils.cc b/driver/rds_utils.cc index 26b6f9157..a49667a0c 100644 --- a/driver/rds_utils.cc +++ b/driver/rds_utils.cc @@ -31,7 +31,7 @@ namespace { const std::regex AURORA_DNS_PATTERN( - R"#((.+)\.(proxy-|cluster-|cluster-ro-|cluster-custom-)?([a-zA-Z0-9]+\.[a-zA-Z0-9\-]+\.rds\.amazonaws\.com))#", + R"#((.+)\.(proxy-|cluster-|cluster-ro-|cluster-custom-)?([a-zA-Z0-9]+\.([a-zA-Z0-9\-]+)\.rds\.amazonaws\.com))#", std::regex_constants::icase); const std::regex AURORA_PROXY_DNS_PATTERN(R"#((.+)\.(proxy-)+([a-zA-Z0-9]+\.[a-zA-Z0-9\-]+\.rds\.amazonaws\.com))#", std::regex_constants::icase); @@ -133,16 +133,64 @@ std::string RDS_UTILS::get_rds_cluster_host_url(std::string host) { return f(AURORA_CHINA_CLUSTER_PATTERN); } +std::string RDS_UTILS::get_rds_cluster_id(std::string host) { + auto f = [host](const std::regex pattern) { + std::smatch m; + if (std::regex_search(host, m, pattern) && m.size() > 1 && !m.str(2).empty()) { + return m.str(1); + } + return std::string(); + }; + + auto result = f(AURORA_DNS_PATTERN); + if (!result.empty()) { + return result; + } + + return f(AURORA_CHINA_DNS_PATTERN); +} + +std::string RDS_UTILS::get_rds_instance_id(std::string host) { + auto f = [host](const std::regex pattern) { + std::smatch m; + if (std::regex_search(host, m, pattern) && m.size() > 1 && m.str(2).empty()) { + return m.str(1); + } + return std::string(); + }; + + auto result = f(AURORA_DNS_PATTERN); + if (!result.empty()) { + return result; + } + + return f(AURORA_CHINA_DNS_PATTERN); +} std::string RDS_UTILS::get_rds_instance_host_pattern(std::string host) { auto f = [host](const std::regex pattern) { std::smatch m; - if (std::regex_search(host, m, pattern) && m.size() > 3) { - if (!m.str(3).empty()) { - std::string result("?."); - result.append(m.str(3)); + if (std::regex_search(host, m, pattern) && m.size() > 4 && !m.str(3).empty()) { + std::string result("?."); + result.append(m.str(3)); - return result; - } + return result; + } + return std::string(); + }; + + auto result = f(AURORA_DNS_PATTERN); + if (!result.empty()) { + return result; + } + + return f(AURORA_CHINA_DNS_PATTERN); +} + +std::string RDS_UTILS::get_rds_region(std::string host) { + auto f = [host](const std::regex pattern) { + std::smatch m; + if (std::regex_search(host, m, pattern) && m.size() > 4 && !m.str(4).empty()) { + return m.str(4); } return std::string(); }; diff --git a/driver/rds_utils.h b/driver/rds_utils.h index 1475129a2..4c4b1648c 100644 --- a/driver/rds_utils.h +++ b/driver/rds_utils.h @@ -45,7 +45,10 @@ class RDS_UTILS { static bool is_ipv6(std::string host); static std::string get_rds_cluster_host_url(std::string host); + static std::string get_rds_cluster_id(std::string host); static std::string get_rds_instance_host_pattern(std::string host); + static std::string get_rds_instance_id(std::string host); + static std::string get_rds_region(std::string host); }; #endif diff --git a/driver/sliding_expiration_cache.cc b/driver/sliding_expiration_cache.cc index f1a1169f0..af56f0a1c 100644 --- a/driver/sliding_expiration_cache.cc +++ b/driver/sliding_expiration_cache.cc @@ -33,6 +33,8 @@ #include #include +#include "custom_endpoint_monitor.h" + template void SLIDING_EXPIRATION_CACHE::remove_and_dispose(K key) { if (this->cache.count(key)) { @@ -64,11 +66,11 @@ void SLIDING_EXPIRATION_CACHE::clean_up() { } template -V SLIDING_EXPIRATION_CACHE::compute_if_absent(K key, std::function mapping_function, +V SLIDING_EXPIRATION_CACHE::compute_if_absent(K key, std::function mapping_function, long long item_expiration_nanos) { this->clean_up(); - auto cache_item = std::make_shared(mapping_function(key), - std::chrono::steady_clock::now() + std::chrono::nanoseconds(item_expiration_nanos)); + V item = mapping_function(key); + auto cache_item = std::make_shared(item, std::chrono::steady_clock::now() + std::chrono::nanoseconds(item_expiration_nanos)); this->cache.emplace(key, cache_item); return cache_item->with_extend_expiration(item_expiration_nanos)->item; } @@ -135,3 +137,6 @@ void SLIDING_EXPIRATION_CACHE::set_clean_up_interval_nanos(long long clean } template class SLIDING_EXPIRATION_CACHE; +template class SLIDING_EXPIRATION_CACHE>; +template class SHOULD_DISPOSE_FUNC>; +template class ITEM_DISPOSAL_FUNC>; diff --git a/driver/sliding_expiration_cache.h b/driver/sliding_expiration_cache.h index 642226b6f..9922a6948 100644 --- a/driver/sliding_expiration_cache.h +++ b/driver/sliding_expiration_cache.h @@ -39,13 +39,15 @@ template class SHOULD_DISPOSE_FUNC { public: - virtual bool should_dispose(T item); + virtual ~SHOULD_DISPOSE_FUNC() = default; + virtual bool should_dispose(T item) { return true; }; }; template class ITEM_DISPOSAL_FUNC { public: - virtual void dispose(T item); + virtual ~ITEM_DISPOSAL_FUNC() = default; + virtual void dispose(T item) {/* Do nothing. */}; }; template @@ -55,7 +57,7 @@ class SLIDING_EXPIRATION_CACHE { public: CACHE_ITEM() = default; CACHE_ITEM(V item, std::chrono::steady_clock::time_point expiration_time) - : item(item), expiration_time(expiration_time){}; + : item(item), expiration_time(expiration_time){}; ~CACHE_ITEM() = default; V item; @@ -64,7 +66,7 @@ class SLIDING_EXPIRATION_CACHE { return this; } - bool should_clean_up(SHOULD_DISPOSE_FUNC* should_dispose_func) { + bool should_clean_up(std::shared_ptr> should_dispose_func) { if (should_dispose_func != nullptr) { return std::chrono::steady_clock::now() > this->expiration_time && should_dispose_func->should_dispose(this->item); @@ -82,15 +84,16 @@ class SLIDING_EXPIRATION_CACHE { this->item_disposal_func = nullptr; } - SLIDING_EXPIRATION_CACHE(SHOULD_DISPOSE_FUNC* should_dispose_func, ITEM_DISPOSAL_FUNC* item_disposal_func) - : should_dispose_func(should_dispose_func), item_disposal_func(item_disposal_func){}; - SLIDING_EXPIRATION_CACHE(SHOULD_DISPOSE_FUNC* should_dispose_func, ITEM_DISPOSAL_FUNC* item_disposal_func, - long long clean_up_interval_nanos) - : clean_up_interval_nanos(clean_up_interval_nanos), - should_dispose_func(should_dispose_func), - item_disposal_func(item_disposal_func){}; + SLIDING_EXPIRATION_CACHE(std::shared_ptr> should_dispose_func, + std::shared_ptr> item_disposal_func) + : should_dispose_func(should_dispose_func), item_disposal_func(item_disposal_func){}; + SLIDING_EXPIRATION_CACHE(std::shared_ptr> should_dispose_func, + std::shared_ptr> item_disposal_func, long long clean_up_interval_nanos) + : clean_up_interval_nanos(clean_up_interval_nanos), + should_dispose_func(std::move(should_dispose_func)), + item_disposal_func(std::move(item_disposal_func)){}; - V compute_if_absent(K key, std::function mapping_function, long long item_expiration_nanos); + V compute_if_absent(K key, std::function mapping_function, long long item_expiration_nanos); V put(K key, V value, long long item_expiration_nanos); V get(K key, long long item_expiration_nanos, V default_value); @@ -121,17 +124,16 @@ class SLIDING_EXPIRATION_CACHE { std::unordered_map> cache; long long clean_up_interval_nanos = 6000000000; // 1 minutes std::atomic clean_up_time_nanos; - SHOULD_DISPOSE_FUNC* should_dispose_func; - ITEM_DISPOSAL_FUNC* item_disposal_func; + std::shared_ptr> should_dispose_func; + std::shared_ptr> item_disposal_func; void remove_and_dispose(K key); void remove_if_expired(K key) { - std::vector items; if (this->cache.count(key)) { std::shared_ptr cache_item = this->cache[key]; if (cache_item != nullptr && cache_item->should_clean_up(this->should_dispose_func)) { if (item_disposal_func != nullptr) { - item_disposal_func->dispose(items[0]); + item_disposal_func->dispose(cache_item->item); } this->cache.erase(key); } diff --git a/driver/sliding_expiration_cache_with_clean_up_thread.cc b/driver/sliding_expiration_cache_with_clean_up_thread.cc index 815db1e76..b034a5df5 100644 --- a/driver/sliding_expiration_cache_with_clean_up_thread.cc +++ b/driver/sliding_expiration_cache_with_clean_up_thread.cc @@ -31,13 +31,14 @@ #include +#include "custom_endpoint_monitor.h" + template void SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD::init_clean_up_thread() { if (!this->is_initialized) { std::unique_lock lock(mutex_); if (!this->is_initialized) { this->clean_up_thread_pool.resize(this->clean_up_thread_pool.size() + 1); - this->clean_up_thread_pool.push([=](int id) { while (!should_stop) { const std::chrono::nanoseconds clean_up_interval = std::chrono::nanoseconds(this->clean_up_interval_nanos); @@ -58,41 +59,23 @@ void SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD::init_clean_up_thread() } } -template -SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD::SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD() { - this->init_clean_up_thread(); -} template -SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD::SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD( - SHOULD_DISPOSE_FUNC* should_dispose_func, ITEM_DISPOSAL_FUNC* item_disposal_func) - : SLIDING_EXPIRATION_CACHE(should_dispose_func, item_disposal_func) { - this->init_clean_up_thread(); +bool SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD::empty() { + return this->cache.empty(); } -template -SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD::SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD( - SHOULD_DISPOSE_FUNC* should_dispose_func, ITEM_DISPOSAL_FUNC* item_disposal_func, - long long clean_up_interval_nanos) - : SLIDING_EXPIRATION_CACHE(should_dispose_func, item_disposal_func, clean_up_interval_nanos) { - this->init_clean_up_thread(); -} - -#ifdef UNIT_TEST_BUILD -template -SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD::SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD( - long long clean_up_interval_nanos) { - this->clean_up_interval_nanos = clean_up_interval_nanos; - this->init_clean_up_thread(); -} -#endif - template void SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD::release_resources() { - this->should_stop = true; - this->clean_up_thread_pool.stop(true); - this->clean_up_thread_pool.resize(0); - this->is_initialized = false; + std::unique_lock lock(mutex_); + { + this->clear(); + this->should_stop = true; + this->clean_up_thread_pool.stop(true); + this->clean_up_thread_pool.resize(0); + this->is_initialized = false; + } } template class SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD; +template class SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD>; diff --git a/driver/sliding_expiration_cache_with_clean_up_thread.h b/driver/sliding_expiration_cache_with_clean_up_thread.h index ebc5f0c46..04a0ee2e0 100644 --- a/driver/sliding_expiration_cache_with_clean_up_thread.h +++ b/driver/sliding_expiration_cache_with_clean_up_thread.h @@ -38,32 +38,35 @@ template class SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD : public SLIDING_EXPIRATION_CACHE { public: - SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD(); - SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD(SHOULD_DISPOSE_FUNC* should_dispose_func, - ITEM_DISPOSAL_FUNC* item_disposal_func); - SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD(SHOULD_DISPOSE_FUNC* should_dispose_func, - ITEM_DISPOSAL_FUNC* item_disposal_func, - long long clean_up_interval_nanos); + SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD() = default; + SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD(std::shared_ptr> should_dispose_func, + std::shared_ptr> item_disposal_func) + : SLIDING_EXPIRATION_CACHE(std::move(should_dispose_func), std::move(item_disposal_func)){}; + SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD(std::shared_ptr> should_dispose_func, + std::shared_ptr> item_disposal_func, + long long clean_up_interval_nanos) + : SLIDING_EXPIRATION_CACHE(std::move(should_dispose_func), std::move(item_disposal_func), + clean_up_interval_nanos){}; ~SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD() = default; - /** * Stop clean up thread. Should be called at the end of the cache's lifetime. */ void release_resources(); + bool empty(); #ifdef UNIT_TEST_BUILD - SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD(long long clean_up_interval_nanos); + SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD(long long clean_up_interval_nanos) { + this->clean_up_interval_nanos = clean_up_interval_nanos; + }; #endif + void init_clean_up_thread(); protected: bool is_initialized = false; bool should_stop = false; std::mutex mutex_; ctpl::thread_pool clean_up_thread_pool; - - private: - void init_clean_up_thread(); }; #endif diff --git a/driver/topology_service.cc b/driver/topology_service.cc index 265dee136..8ec39f40b 100644 --- a/driver/topology_service.cc +++ b/driver/topology_service.cc @@ -29,6 +29,7 @@ #include "cluster_aware_metrics_container.h" #include "topology_service.h" +#include TOPOLOGY_SERVICE::TOPOLOGY_SERVICE(unsigned long dbc_id, bool enable_logging) : dbc_id{dbc_id}, @@ -164,20 +165,15 @@ std::shared_ptr TOPOLOGY_SERVICE::get_cached_topology() { return get_from_cache(); } -//TODO consider the return value -//Note to determine whether or not force_update succeeded one would compare -// CLUSTER_TOPOLOGY_INFO->time_last_updated() prior and after the call if non-null information was given prior. -std::shared_ptr TOPOLOGY_SERVICE::get_topology(CONNECTION_PROXY* connection, bool force_update) -{ - //TODO reconsider using this cache. It appears that we only store information for the current cluster Id. +std::shared_ptr TOPOLOGY_SERVICE::get_topology(CONNECTION_PROXY* connection, bool force_update) { + // TODO reconsider using this cache. It appears that we only store information for the current cluster Id. // therefore instead of a map we can just keep CLUSTER_TOPOLOGY_INFO* topology_info member variable. auto cached_topology = get_from_cache(); if (!cached_topology || force_update || refresh_needed(cached_topology->time_last_updated())) { - auto latest_topology = query_for_topology(connection); - if (latest_topology) { + if (auto latest_topology = query_for_topology(connection)) { put_to_cache(latest_topology); return latest_topology; } @@ -186,6 +182,39 @@ std::shared_ptr TOPOLOGY_SERVICE::get_topology(CONNECTION return cached_topology; } +std::shared_ptr TOPOLOGY_SERVICE::get_filtered_topology(std::shared_ptr topology) { + if (!this->allowed_and_blocked_hosts) { + return topology; + } + + std::set allowed_list = this->allowed_and_blocked_hosts->get_allowed_host_ids(); + std::set blocked_list = this->allowed_and_blocked_hosts->get_blocked_host_ids(); + + const std::shared_ptr filtered_topology = std::make_shared(); + for (const auto& host : topology->get_instances()) { + const auto host_id = host->get_host_id(); + if (allowed_list.find(host_id) != allowed_list.end() && blocked_list.find(host_id) == blocked_list.end()) { + filtered_topology->add_host(host); + } + } + + return filtered_topology; +} + +std::string TOPOLOGY_SERVICE::log_topology(const std::shared_ptr topology) { + std::stringstream topology_str; + topology_str << "[TOPOLOGY_SERVICE] Topology: "; + if (topology->total_hosts() == 0) { + topology_str << ""; + return topology_str.str(); + } + for (const auto& host : topology->get_instances()) { + topology_str << "\n\t" << *host; + } + + return topology_str.str(); +} + // TODO consider thread safety and usage of pointers std::shared_ptr TOPOLOGY_SERVICE::get_from_cache() { if (topology_cache.empty()) { diff --git a/driver/topology_service.h b/driver/topology_service.h index 60f4b6d4d..bb1c1378a 100644 --- a/driver/topology_service.h +++ b/driver/topology_service.h @@ -24,7 +24,7 @@ // See the GNU General Public License, version 2.0, for more details. // // You should have received a copy of the GNU General Public License -// along with this program. If not, see +// along with this program. If not, see // http://www.gnu.org/licenses/gpl-2.0.html. #ifndef __TOPOLOGYSERVICE_H__ @@ -36,16 +36,18 @@ #include "cluster_topology_info.h" #include "connection_proxy.h" -#include -#include #include #include +#include +#include +#include "allowed_and_blocked_hosts.h" // TODO - consider - do we really need miliseconds for refresh? - the default numbers here are already 30 seconds.000; #define DEFAULT_REFRESH_RATE_IN_MILLISECONDS 30000 #define WRITER_SESSION_ID "MASTER_SESSION_ID" -#define RETRIEVE_TOPOLOGY_SQL "SELECT SERVER_ID, SESSION_ID, LAST_UPDATE_TIMESTAMP, REPLICA_LAG_IN_MILLISECONDS \ +#define RETRIEVE_TOPOLOGY_SQL \ + "SELECT SERVER_ID, SESSION_ID, LAST_UPDATE_TIMESTAMP, REPLICA_LAG_IN_MILLISECONDS \ FROM information_schema.replica_host_status \ WHERE time_to_sec(timediff(now(), LAST_UPDATE_TIMESTAMP)) <= 300 \ ORDER BY LAST_UPDATE_TIMESTAMP DESC" @@ -54,69 +56,75 @@ static std::map> topology_ca static std::mutex topology_cache_mutex; class TOPOLOGY_SERVICE { -public: - TOPOLOGY_SERVICE(unsigned long dbc_id, bool enable_logging = false); - TOPOLOGY_SERVICE(const TOPOLOGY_SERVICE&); - virtual ~TOPOLOGY_SERVICE(); - - virtual void set_cluster_id(std::string cluster_id); - virtual void set_cluster_instance_template(std::shared_ptr host_template); //is this equivalent to setcluster_instance_host - - virtual std::shared_ptr get_topology( - CONNECTION_PROXY* connection, bool force_update = false); - std::shared_ptr get_cached_topology(); - - std::shared_ptr get_last_used_reader(); - void set_last_used_reader(std::shared_ptr reader); - std::set get_down_hosts(); - virtual void mark_host_down(std::shared_ptr host); - virtual void mark_host_up(std::shared_ptr host); - void set_refresh_rate(int refresh_rate); - void set_gather_metric(bool can_gather); - void clear_all(); - void clear(); - - // Property Keys - const std::string SESSION_ID = "TOPOLOGY_SERVICE_SESSION_ID"; - const std::string LAST_UPDATED = "TOPOLOGY_SERVICE_LAST_UPDATE_TIMESTAMP"; - const std::string REPLICA_LAG = "TOPOLOGY_SERVICE_REPLICA_LAG_IN_MILLISECONDS"; - const std::string INSTANCE_NAME = "TOPOLOGY_SERVICE_SERVER_ID"; - -private: - const int DEFAULT_CACHE_EXPIRE_MS = 5 * 60 * 1000; // 5 min - - const std::string GET_INSTANCE_NAME_SQL = "SELECT @@aurora_server_id"; - const std::string GET_INSTANCE_NAME_COL = "@@aurora_server_id"; - - const std::string FIELD_SERVER_ID = "SERVER_ID"; - const std::string FIELD_SESSION_ID = "SESSION_ID"; - const std::string FIELD_LAST_UPDATED = "LAST_UPDATE_TIMESTAMP"; - const std::string FIELD_REPLICA_LAG = "REPLICA_LAG_IN_MILLISECONDS"; - - std::shared_ptr logger = nullptr; - unsigned long dbc_id = 0; - -protected: - const int NO_CONNECTION_INDEX = -1; - int refresh_rate_in_ms; - - std::string cluster_id; - std::shared_ptr cluster_instance_host; - - std::shared_ptr metrics_container; - - bool refresh_needed(std::time_t last_updated); - std::shared_ptr query_for_topology(CONNECTION_PROXY* connection); - std::shared_ptr create_host(MYSQL_ROW& row); - std::string get_host_endpoint(const char* node_name); - static bool does_instance_exist( - std::map>& instances, - std::shared_ptr host_info); - - std::shared_ptr get_from_cache(); - void put_to_cache(std::shared_ptr topology_info); - - MYSQL_RES* try_execute_query(CONNECTION_PROXY* connection_proxy, const char* query); + public: + TOPOLOGY_SERVICE(unsigned long dbc_id, bool enable_logging = false); + TOPOLOGY_SERVICE(const TOPOLOGY_SERVICE&); + virtual ~TOPOLOGY_SERVICE(); + + virtual void set_cluster_id(std::string cluster_id); + virtual void set_cluster_instance_template( + std::shared_ptr host_template); // is this equivalent to set_cluster_instance_host + + virtual std::shared_ptr get_topology(CONNECTION_PROXY* connection, bool force_update = false); + virtual std::shared_ptr get_filtered_topology(std::shared_ptr topology); + virtual std::string log_topology(const std::shared_ptr topology); + std::shared_ptr get_cached_topology(); + + std::shared_ptr get_last_used_reader(); + void set_last_used_reader(std::shared_ptr reader); + std::set get_down_hosts(); + virtual void mark_host_down(std::shared_ptr host); + virtual void mark_host_up(std::shared_ptr host); + void set_refresh_rate(int refresh_rate); + void set_gather_metric(bool can_gather); + void clear_all(); + void clear(); + void set_allowed_and_blocked_hosts(std::shared_ptr hosts) { + this->allowed_and_blocked_hosts = hosts; + }; + + // Property Keys + const std::string SESSION_ID = "TOPOLOGY_SERVICE_SESSION_ID"; + const std::string LAST_UPDATED = "TOPOLOGY_SERVICE_LAST_UPDATE_TIMESTAMP"; + const std::string REPLICA_LAG = "TOPOLOGY_SERVICE_REPLICA_LAG_IN_MILLISECONDS"; + const std::string INSTANCE_NAME = "TOPOLOGY_SERVICE_SERVER_ID"; + + private: + const int DEFAULT_CACHE_EXPIRE_MS = 5 * 60 * 1000; // 5 min + + const std::string GET_INSTANCE_NAME_SQL = "SELECT @@aurora_server_id"; + const std::string GET_INSTANCE_NAME_COL = "@@aurora_server_id"; + + const std::string FIELD_SERVER_ID = "SERVER_ID"; + const std::string FIELD_SESSION_ID = "SESSION_ID"; + const std::string FIELD_LAST_UPDATED = "LAST_UPDATE_TIMESTAMP"; + const std::string FIELD_REPLICA_LAG = "REPLICA_LAG_IN_MILLISECONDS"; + + std::shared_ptr logger = nullptr; + unsigned long dbc_id = 0; + + protected: + const int NO_CONNECTION_INDEX = -1; + int refresh_rate_in_ms; + + std::string cluster_id; + std::shared_ptr cluster_instance_host; + + std::shared_ptr metrics_container; + + std::shared_ptr allowed_and_blocked_hosts; + + bool refresh_needed(std::time_t last_updated); + std::shared_ptr query_for_topology(CONNECTION_PROXY* connection); + std::shared_ptr create_host(MYSQL_ROW& row); + std::string get_host_endpoint(const char* node_name); + static bool does_instance_exist(std::map>& instances, + std::shared_ptr host_info); + + std::shared_ptr get_from_cache(); + void put_to_cache(std::shared_ptr topology_info); + + MYSQL_RES* try_execute_query(CONNECTION_PROXY* connection_proxy, const char* query); }; #endif /* __TOPOLOGYSERVICE_H__ */ diff --git a/integration/CMakeLists.txt b/integration/CMakeLists.txt index fe716112a..055b59d81 100644 --- a/integration/CMakeLists.txt +++ b/integration/CMakeLists.txt @@ -99,12 +99,12 @@ set(TEST_SOURCES base_failover_integration_test.cc connection_string_builder_test.cc) set(INTEGRATION_TESTS + custom_endpoint_integration_test.cc iam_authentication_integration_test.cc secrets_manager_integration_test.cc network_failover_integration_test.cc failover_integration_test.cc ) - if(NOT ENABLE_PERFORMANCE_TESTS) set(TEST_SOURCES ${TEST_SOURCES} ${INTEGRATION_TESTS}) else() diff --git a/integration/base_failover_integration_test.cc b/integration/base_failover_integration_test.cc index d910703cd..5172900f0 100644 --- a/integration/base_failover_integration_test.cc +++ b/integration/base_failover_integration_test.cc @@ -37,7 +37,9 @@ #include #include #include +#include #include +#include #include #include @@ -49,6 +51,7 @@ #include #include #include +#include #include #include @@ -269,6 +272,13 @@ class BaseFailoverIntegrationTest : public testing::Test { } } + static Aws::RDS::Model::DBClusterEndpoint get_custom_endpoint_info(const Aws::RDS::RDSClient& client, const std::string& endpoint_id) { + Aws::RDS::Model::DescribeDBClusterEndpointsRequest request; + request.SetDBClusterEndpointIdentifier(endpoint_id); + const auto response = client.DescribeDBClusterEndpoints(request); + return response.GetResult().GetDBClusterEndpoints()[0]; + } + static Aws::RDS::Model::DBClusterMember get_DB_cluster_writer_instance(const Aws::RDS::RDSClient& client, const Aws::String& cluster_id) { Aws::RDS::Model::DBClusterMember instance; const Aws::RDS::Model::DBCluster cluster = get_DB_cluster(client, cluster_id); diff --git a/integration/connection_string_builder.h b/integration/connection_string_builder.h index 96d08ce37..a30334bcb 100644 --- a/integration/connection_string_builder.h +++ b/integration/connection_string_builder.h @@ -166,6 +166,38 @@ class ConnectionStringBuilder { return *this; } + ConnectionStringBuilder& withEnableCustomEndpointMonitoring(const bool& enable_custom_endpoint_monitoring) { + length += + sprintf(conn_in + length, "ENABLE_CUSTOM_ENDPOINT_MONITORING=%d;", enable_custom_endpoint_monitoring ? 1 : 0); + return *this; + } + ConnectionStringBuilder& withCustomEndpointRegion(const std::string& region) { + length += sprintf(conn_in + length, "CUSTOM_ENDPOINT_REGION=%s;", region.c_str()); + return *this; + } + + ConnectionStringBuilder& withShouldWaitForInfo(const bool& should_wait_for_info) { + length += sprintf(conn_in + length, "WAIT_FOR_CUSTOM_ENDPOINT_INFO=%d;", should_wait_for_info ? 1 : 0); + return *this; + } + + ConnectionStringBuilder& withCustomEndpointInfoRefreshRateMs(const long& custom_endpoint_info_refresh_rate_ms) { + length += + sprintf(conn_in + length, "CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS=%ld;", custom_endpoint_info_refresh_rate_ms); + return *this; + } + + ConnectionStringBuilder& withWaitOnCachedInfoDurationMs(const long& wait_on_cached_info_duration_ms) { + length += + sprintf(conn_in + length, "WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS=%ld;", wait_on_cached_info_duration_ms); + return *this; + } + + ConnectionStringBuilder& withIdleMonitorExpirationMs(const long& idle_monitor_expiration_ms) { + length += sprintf(conn_in + length, "CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS=%ld;", idle_monitor_expiration_ms); + return *this; + } + std::string getString() const { return conn_in; } private: diff --git a/integration/custom_endpoint_integration_test.cc b/integration/custom_endpoint_integration_test.cc new file mode 100644 index 000000000..f2202cfc2 --- /dev/null +++ b/integration/custom_endpoint_integration_test.cc @@ -0,0 +1,172 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#include +#include +#include "base_failover_integration_test.cc" + +class CustomEndpointIntegrationTest : public BaseFailoverIntegrationTest { + protected: + std::string ACCESS_KEY = std::getenv("AWS_ACCESS_KEY_ID"); + std::string SECRET_ACCESS_KEY = std::getenv("AWS_SECRET_ACCESS_KEY"); + std::string SESSION_TOKEN = std::getenv("AWS_SESSION_TOKEN"); + std::string RDS_ENDPOINT = std::getenv("RDS_ENDPOINT"); + std::string RDS_REGION = std::getenv("RDS_REGION"); + std::string ENDPOINT_ID = + "test-endpoint-1-" + std::to_string(std::chrono::steady_clock::now().time_since_epoch().count()); + std::string region = "us-east-2"; + Aws::RDS::RDSClientConfiguration client_config; + Aws::RDS::RDSClient rds_client; + SQLHENV env = nullptr; + SQLHDBC dbc = nullptr; + + bool is_endpoint_created = false; + + static void SetUpTestSuite() { Aws::InitAPI(options); } + + static void TearDownTestSuite() { Aws::ShutdownAPI(options); } + void SetUp() override { + if (SQLAllocHandle(SQL_HANDLE_ENV, nullptr, &env) != SQL_SUCCESS) { + throw std::runtime_error("Failed to allocate handles for integration tests."); + } + SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, reinterpret_cast(SQL_OV_ODBC3), 0); + if (SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc) != SQL_SUCCESS) { + throw std::runtime_error("Failed to allocate handles for integration tests."); + } + + Aws::Auth::AWSCredentials credentials = + SESSION_TOKEN.empty() ? Aws::Auth::AWSCredentials(Aws::String(ACCESS_KEY), Aws::String(SECRET_ACCESS_KEY)) + : Aws::Auth::AWSCredentials(Aws::String(ACCESS_KEY), Aws::String(SECRET_ACCESS_KEY), + Aws::String(SESSION_TOKEN)); + if (!RDS_REGION.empty()) { + region = RDS_REGION; + } + client_config.region = region; + if (!RDS_ENDPOINT.empty()) { + client_config.endpointOverride = RDS_ENDPOINT; + } + rds_client = Aws::RDS::RDSClient(credentials, client_config); + + cluster_instances = retrieve_topology_via_SDK(rds_client, cluster_id); + writer_id = get_writer_id(cluster_instances); + readers = get_readers(cluster_instances); + target_writer_id = get_random_DB_cluster_reader_instance_id(readers); + + if (!is_endpoint_created) { + const std::vector instance{writer_id}; + create_custom_endpoint(cluster_id, instance); + wait_until_endpoint_available(cluster_id); + } + } + + void create_custom_endpoint(const std::string& cluster_id, const std::vector& writer) const { + Aws::RDS::Model::CreateDBClusterEndpointRequest rds_req; + rds_req.SetDBClusterEndpointIdentifier(ENDPOINT_ID); + rds_req.SetDBClusterIdentifier(cluster_id); + rds_req.SetEndpointType("ANY"); + rds_req.SetStaticMembers(writer); + rds_client.CreateDBClusterEndpoint(rds_req); + } + + void delete_custom_endpoint() { + Aws::RDS::Model::DeleteDBClusterEndpointRequest rds_req; + rds_req.SetDBClusterEndpointIdentifier(ENDPOINT_ID); + rds_client.DeleteDBClusterEndpoint(rds_req); + } + + /** + * Wait up to 5 minutes for the new custom endpoint to become unavailable. + */ + void wait_until_endpoint_available(const std::string& cluster_id) const { + const std::chrono::steady_clock::time_point end_time = std::chrono::steady_clock::now() + std::chrono::minutes(5); + bool is_available = false; + + Aws::String status = get_DB_cluster(rds_client, cluster_id).GetStatus(); + + while (std::chrono::steady_clock::now() < end_time) { + const auto endpoint_info = get_custom_endpoint_info(rds_client, ENDPOINT_ID); + is_available = endpoint_info.GetStatus() == "available"; + if (is_available) { + break; + } + std::this_thread::sleep_for(std::chrono::seconds(3)); + } + + if (!is_available) { + throw std::runtime_error( + "The test setup step timed out while waiting for the custom endpoint to become available: " + ENDPOINT_ID); + } + } + + void TearDown() override { + delete_custom_endpoint(); + + if (nullptr != dbc) { + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + } + if (nullptr != env) { + SQLFreeHandle(SQL_HANDLE_ENV, env); + } + } +}; + +TEST_F(CustomEndpointIntegrationTest, test_CustomeEndpointFailover) { + const auto endpoint_info = get_custom_endpoint_info(rds_client, ENDPOINT_ID); + + connection_string = ConnectionStringBuilder(dsn, endpoint_info.GetEndpoint(), MYSQL_PORT) + .withLogQuery(true) + .withEnableFailureDetection(true) + .withUID(user) + .withPWD(pwd) + .withDatabase(db) + .withFailoverMode("reader or writer") + .withEnableCustomEndpointMonitoring(true) + .withCustomEndpointRegion(region) + .getString(); + SQLCHAR conn_out[4096] = "\0"; + SQLSMALLINT len; + + EXPECT_EQ(SQL_SUCCESS, SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS, conn_out, + MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT)); + + const std::vector& endpoint_members = endpoint_info.GetStaticMembers(); + const std::string current_connection_id = query_instance_id(dbc); + + EXPECT_NE(std::find(endpoint_members.begin(), endpoint_members.end(), current_connection_id), endpoint_members.end()); + + failover_cluster_and_wait_until_writer_changed( + rds_client, cluster_id, writer_id, current_connection_id == writer_id ? target_writer_id : current_connection_id); + + assert_query_failed(dbc, SERVER_ID_QUERY, ERROR_COMM_LINK_CHANGED); + + const std::string new_connection_id = query_instance_id(dbc); + + EXPECT_NE(std::find(endpoint_members.begin(), endpoint_members.end(), new_connection_id), endpoint_members.end()); + EXPECT_EQ(SQL_SUCCESS, SQLDisconnect(dbc)); +} diff --git a/scripts/build_aws_sdk_unix.sh b/scripts/build_aws_sdk_unix.sh index 4e73c5fdd..2a7388b0b 100755 --- a/scripts/build_aws_sdk_unix.sh +++ b/scripts/build_aws_sdk_unix.sh @@ -40,7 +40,7 @@ AWS_INSTALL_DIR=$AWS_SRC_DIR/../install mkdir -p $AWS_SRC_DIR $AWS_BUILD_DIR $AWS_INSTALL_DIR -git clone --recurse-submodules -b "1.11.394" "https://github.com/aws/aws-sdk-cpp.git" $AWS_SRC_DIR +git clone --recurse-submodules -b "1.11.488" "https://github.com/aws/aws-sdk-cpp.git" $AWS_SRC_DIR cmake -S $AWS_SRC_DIR -B $AWS_BUILD_DIR -DCMAKE_INSTALL_PREFIX="${AWS_INSTALL_DIR}" -DCMAKE_BUILD_TYPE="${CONFIGURATION}" -DBUILD_ONLY="rds;secretsmanager;sts" -DENABLE_TESTING="OFF" -DBUILD_SHARED_LIBS="ON" -DCPP_STANDARD="14" cd $AWS_BUILD_DIR diff --git a/scripts/build_aws_sdk_win.ps1 b/scripts/build_aws_sdk_win.ps1 index 2ec5c4935..d887a8cab 100644 --- a/scripts/build_aws_sdk_win.ps1 +++ b/scripts/build_aws_sdk_win.ps1 @@ -44,7 +44,7 @@ Write-Host $args # Make AWS SDK source directory New-Item -Path $SRC_DIR -ItemType Directory -Force | Out-Null # Clone the AWS SDK CPP repo -git clone --recurse-submodules -b "1.11.394" "https://github.com/aws/aws-sdk-cpp.git" $SRC_DIR +git clone --recurse-submodules -b "1.11.488" "https://github.com/aws/aws-sdk-cpp.git" $SRC_DIR # Make and move to build directory New-Item -Path $BUILD_DIR -ItemType Directory -Force | Out-Null diff --git a/setupgui/callbacks.cc b/setupgui/callbacks.cc index 53177485b..80776f34a 100644 --- a/setupgui/callbacks.cc +++ b/setupgui/callbacks.cc @@ -346,7 +346,15 @@ void syncTabsData(HWND hwnd, DataSource *params) GET_UNSIGNED_TAB(FED_AUTH_TAB, CLIENT_SOCKET_TIMEOUT); GET_BOOL_TAB(FED_AUTH_TAB, ENABLE_SSL); - /* 5 - Failover */ + /* 5 - Custom Endpoint */ + GET_BOOL_TAB(CUSTOM_ENDPOINT_TAB, ENABLE_CUSTOM_ENDPOINT_MONITORING); + GET_BOOL_TAB(CUSTOM_ENDPOINT_TAB, WAIT_FOR_CUSTOM_ENDPOINT_INFO); + GET_UNSIGNED_TAB(CUSTOM_ENDPOINT_TAB, CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS); + GET_UNSIGNED_TAB(CUSTOM_ENDPOINT_TAB, CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS); + GET_UNSIGNED_TAB(CUSTOM_ENDPOINT_TAB, WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS); + GET_STRING_TAB(CUSTOM_ENDPOINT_TAB, CUSTOM_ENDPOINT_REGION); + + /* 6 - Failover */ GET_BOOL_TAB(FAILOVER_TAB, ENABLE_CLUSTER_FAILOVER); GET_COMBO_TAB(FAILOVER_TAB, FAILOVER_MODE); GET_BOOL_TAB(FAILOVER_TAB, GATHER_PERF_METRICS); @@ -365,7 +373,7 @@ void syncTabsData(HWND hwnd, DataSource *params) GET_UNSIGNED_TAB(FAILOVER_TAB, CONNECT_TIMEOUT); GET_UNSIGNED_TAB(FAILOVER_TAB, NETWORK_TIMEOUT); - /* 6 - Monitoring */ + /* 7 - Monitoring */ GET_BOOL_TAB(MONITORING_TAB, ENABLE_FAILURE_DETECTION); if (READ_BOOL_TAB(MONITORING_TAB, ENABLE_FAILURE_DETECTION)) { @@ -376,7 +384,7 @@ void syncTabsData(HWND hwnd, DataSource *params) GET_UNSIGNED_TAB(MONITORING_TAB, MONITOR_DISPOSAL_TIME); } - /* 7 - Metadata*/ + /* 8 - Metadata*/ GET_BOOL_TAB(METADATA_TAB, NO_BIGINT); GET_BOOL_TAB(METADATA_TAB, NO_BINARY_RESULT); GET_BOOL_TAB(METADATA_TAB, FULL_COLUMN_NAMES); @@ -384,7 +392,7 @@ void syncTabsData(HWND hwnd, DataSource *params) GET_BOOL_TAB(METADATA_TAB, NO_SCHEMA); GET_BOOL_TAB(METADATA_TAB, COLUMN_SIZE_S32); - /* 8 - Cursors/Results */ + /* 9 - Cursors/Results */ GET_BOOL_TAB(CURSORS_TAB, FOUND_ROWS); GET_BOOL_TAB(CURSORS_TAB, AUTO_IS_NULL); GET_BOOL_TAB(CURSORS_TAB, DYNAMIC_CURSOR); @@ -402,10 +410,10 @@ void syncTabsData(HWND hwnd, DataSource *params) { params->opt_PREFETCH = 0; } - /* 9 - debug*/ + /* 10 - debug*/ GET_BOOL_TAB(DEBUG_TAB,LOG_QUERY); - /* 10 - ssl related */ + /* 11 - ssl related */ GET_STRING_TAB(SSL_TAB, SSL_KEY); GET_STRING_TAB(SSL_TAB, SSL_CERT); GET_STRING_TAB(SSL_TAB, SSL_CA); @@ -420,7 +428,7 @@ void syncTabsData(HWND hwnd, DataSource *params) GET_STRING_TAB(SSL_TAB, SSL_CRL); GET_STRING_TAB(SSL_TAB, SSL_CRLPATH); - /* 11 - Misc*/ + /* 12 - Misc*/ GET_BOOL_TAB(MISC_TAB, SAFE); GET_BOOL_TAB(MISC_TAB, NO_LOCALE); GET_BOOL_TAB(MISC_TAB, IGNORE_SPACE); @@ -501,7 +509,28 @@ void syncTabs(HWND hwnd, DataSource *params) SET_UNSIGNED_TAB(FED_AUTH_TAB, CLIENT_SOCKET_TIMEOUT); SET_BOOL_TAB(FED_AUTH_TAB, ENABLE_SSL); - /* 5 - Failover */ + /* 5 - Custom Endpoint */ + SET_BOOL_TAB(CUSTOM_ENDPOINT_TAB, ENABLE_CUSTOM_ENDPOINT_MONITORING); + SET_BOOL_TAB(CUSTOM_ENDPOINT_TAB, WAIT_FOR_CUSTOM_ENDPOINT_INFO); + + if (params->opt_CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS > 0) + { + SET_UNSIGNED_TAB(CUSTOM_ENDPOINT_TAB, CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS); + } + + if (params->opt_CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS > 0) + { + SET_UNSIGNED_TAB(CUSTOM_ENDPOINT_TAB, CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS); + } + + if (params->opt_WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS > 0) + { + SET_UNSIGNED_TAB(CUSTOM_ENDPOINT_TAB, WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS); + } + + SET_STRING_TAB(CUSTOM_ENDPOINT_TAB, CUSTOM_ENDPOINT_REGION); + + /* 6 - Failover */ SET_BOOL_TAB(FAILOVER_TAB, ENABLE_CLUSTER_FAILOVER); SET_COMBO_TAB(FAILOVER_TAB, FAILOVER_MODE); SET_BOOL_TAB(FAILOVER_TAB, GATHER_PERF_METRICS); @@ -552,7 +581,7 @@ void syncTabs(HWND hwnd, DataSource *params) SET_UNSIGNED_TAB(FAILOVER_TAB, NETWORK_TIMEOUT); } - /* 6 - Monitoring */ + /* 7 - Monitoring */ SET_BOOL_TAB(MONITORING_TAB, ENABLE_FAILURE_DETECTION); if (READ_BOOL_TAB(MONITORING_TAB, ENABLE_FAILURE_DETECTION)) { #ifdef _WIN32 @@ -569,7 +598,7 @@ void syncTabs(HWND hwnd, DataSource *params) SET_UNSIGNED_TAB(MONITORING_TAB, FAILURE_DETECTION_TIMEOUT); } - /* 7 - Metadata */ + /* 8 - Metadata */ SET_BOOL_TAB(METADATA_TAB, NO_BIGINT); SET_BOOL_TAB(METADATA_TAB, NO_BINARY_RESULT); SET_BOOL_TAB(METADATA_TAB, FULL_COLUMN_NAMES); @@ -577,7 +606,7 @@ void syncTabs(HWND hwnd, DataSource *params) SET_BOOL_TAB(METADATA_TAB, NO_SCHEMA); SET_BOOL_TAB(METADATA_TAB, COLUMN_SIZE_S32); - /* 8 - Cursors/Results */ + /* 9 - Cursors/Results */ SET_BOOL_TAB(CURSORS_TAB, FOUND_ROWS); SET_BOOL_TAB(CURSORS_TAB, AUTO_IS_NULL); SET_BOOL_TAB(CURSORS_TAB, DYNAMIC_CURSOR); @@ -596,10 +625,10 @@ void syncTabs(HWND hwnd, DataSource *params) SET_UNSIGNED_TAB(CURSORS_TAB, PREFETCH); } - /* 9 - debug*/ + /* 10 - debug*/ SET_BOOL_TAB(DEBUG_TAB,LOG_QUERY); - /* 10 - ssl related */ + /* 11 - ssl related */ #ifdef _WIN32 if ( getTabCtrlTabPages(SSL_TAB-1) ) #endif @@ -637,7 +666,7 @@ void syncTabs(HWND hwnd, DataSource *params) SET_STRING_TAB(SSL_TAB, TLS_VERSIONS); } - /* 11 - Misc*/ + /* 12 - Misc*/ SET_BOOL_TAB(MISC_TAB, SAFE); SET_BOOL_TAB(MISC_TAB, NO_LOCALE); SET_BOOL_TAB(MISC_TAB, IGNORE_SPACE); diff --git a/setupgui/setupgui.h b/setupgui/setupgui.h index 903476b05..f9bd3c1a3 100644 --- a/setupgui/setupgui.h +++ b/setupgui/setupgui.h @@ -39,13 +39,14 @@ #define AUTH_TAB 2 #define AWS_AUTH_TAB 3 #define FED_AUTH_TAB 4 -#define FAILOVER_TAB 5 -#define MONITORING_TAB 6 -#define METADATA_TAB 7 -#define CURSORS_TAB 8 -#define DEBUG_TAB 9 -#define SSL_TAB 10 -#define MISC_TAB 11 +#define CUSTOM_ENDPOINT_TAB 5 +#define FAILOVER_TAB 6 +#define MONITORING_TAB 7 +#define METADATA_TAB 8 +#define CURSORS_TAB 9 +#define DEBUG_TAB 10 +#define SSL_TAB 11 +#define MISC_TAB 12 #else # include diff --git a/setupgui/windows/odbcdialogparams.cpp b/setupgui/windows/odbcdialogparams.cpp index 64bf5e332..174a0665a 100644 --- a/setupgui/windows/odbcdialogparams.cpp +++ b/setupgui/windows/odbcdialogparams.cpp @@ -376,6 +376,7 @@ void btnDetails_Click (HWND hwnd) L"Authentication", L"AWS Authentication", L"Federated Authentication", + L"Custom Endpoint Monitoring", L"Cluster Failover", L"Monitoring", L"Metadata", @@ -396,6 +397,7 @@ void btnDetails_Click (HWND hwnd) MAKEINTRESOURCE(IDD_TAB9), MAKEINTRESOURCE(IDD_TAB10), MAKEINTRESOURCE(IDD_TAB11), + MAKEINTRESOURCE(IDD_TAB12), 0}; New_TabControl( &TabCtrl_1, // address of TabControl struct diff --git a/setupgui/windows/odbcdialogparams.rc b/setupgui/windows/odbcdialogparams.rc index e0b6f325a..39dc04dd1 100644 --- a/setupgui/windows/odbcdialogparams.rc +++ b/setupgui/windows/odbcdialogparams.rc @@ -252,7 +252,25 @@ BEGIN CONTROL "&Enable SSL",IDC_CHECK_ENABLE_SSL,"Button",BS_AUTOCHECKBOX | WS_TABSTOP,207,108,47,10 END -IDD_TAB5 DIALOGEX 0, 0, 209, 281 +IDD_TAB5 DIALOGEX 0, 0, 209, 181 +STYLE DS_SETFONT | DS_FIXEDSYS | WS_CHILD +FONT 8, "MS Shell Dlg", 400, 0, 0x1 +BEGIN + CONTROL "&Enable custom endpoint monitoring",IDC_CHECK_ENABLE_CUSTOM_ENDPOINT_MONITORING, + "Button",BS_AUTOCHECKBOX | WS_TABSTOP,12,12,147,10 + CONTROL "&Wait for custom endpoint info",IDC_CHECK_WAIT_FOR_CUSTOM_ENDPOINT_INFO, + "Button",BS_AUTOCHECKBOX | WS_TABSTOP,12,27,147,10 + RTEXT "Custom endpoint info refresh rate (ms):",IDC_STATIC,12,42,150,10 + EDITTEXT IDC_EDIT_CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS,165,40,64,12,ES_AUTOHSCROLL | ES_NUMBER + RTEXT "Wait for custom endpoint info timeout (ms):",IDC_STATIC,12,57,150,8 + EDITTEXT IDC_EDIT_WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS,165,55,64,12,ES_AUTOHSCROLL | ES_NUMBER + RTEXT "Custom endpoint monitor expiration time (ms):",IDC_STATIC,12,72,150,8 + EDITTEXT IDC_EDIT_CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS,165,70,64,12,ES_AUTOHSCROLL | ES_NUMBER + RTEXT "Custom endpoint region:",IDC_STATIC,12,87,150,8 + EDITTEXT IDC_EDIT_CUSTOM_ENDPOINT_REGION,165,85,64,12,ES_AUTOHSCROLL +END + +IDD_TAB6 DIALOGEX 0, 0, 209, 281 STYLE DS_SETFONT | DS_FIXEDSYS | WS_CHILD FONT 8, "MS Shell Dlg", 400, 0, 0x1 BEGIN @@ -284,7 +302,7 @@ BEGIN "Button", BS_AUTOCHECKBOX | WS_TABSTOP | WS_DISABLED, 210, 96, 120, 10 END -IDD_TAB6 DIALOGEX 0, 0, 209, 181 +IDD_TAB7 DIALOGEX 0, 0, 209, 181 STYLE DS_SETFONT | DS_FIXEDSYS | WS_CHILD FONT 8, "MS Shell Dlg", 400, 0, 0x1 BEGIN @@ -302,7 +320,7 @@ BEGIN EDITTEXT IDC_EDIT_MONITOR_DISPOSAL_TIME,132,85,64,12,ES_AUTOHSCROLL | ES_NUMBER| WS_DISABLED END -IDD_TAB7 DIALOGEX 0, 0, 209, 181 +IDD_TAB8 DIALOGEX 0, 0, 209, 181 STYLE DS_SETFONT | DS_FIXEDSYS | WS_CHILD FONT 8, "MS Shell Dlg", 400, 0, 0x1 BEGIN @@ -320,7 +338,7 @@ BEGIN "Button",BS_AUTOCHECKBOX | WS_TABSTOP,12,87,141,10 END -IDD_TAB8 DIALOGEX 0, 0, 209, 181 +IDD_TAB9 DIALOGEX 0, 0, 209, 181 STYLE DS_SETFONT | DS_FIXEDSYS | WS_CHILD FONT 8, "MS Shell Dlg", 400, 0, 0x1 BEGIN @@ -345,15 +363,15 @@ BEGIN "Button",BS_AUTOCHECKBOX | WS_TABSTOP,12,125,138,10 END -IDD_TAB9 DIALOGEX 0, 0, 209, 181 +IDD_TAB10 DIALOGEX 0, 0, 209, 181 STYLE DS_SETFONT | DS_FIXEDSYS | WS_CHILD FONT 8, "MS Shell Dlg", 400, 0, 0x1 BEGIN CONTROL "&Log driver activity to %TEMP%\\myodbc.log",IDC_CHECK_LOG_QUERY, - "Button",BS_AUTOCHECKBOX | WS_TABSTOP,12,12,148,10 + "Button",BS_AUTOCHECKBOX | WS_TABSTOP,12,12,1170,10 END -IDD_TAB10 DIALOGEX 0, 0, 509, 181 +IDD_TAB11 DIALOGEX 0, 0, 509, 181 STYLE DS_SETFONT | DS_FIXEDSYS | WS_CHILD FONT 8, "MS Shell Dlg", 400, 0, 0x1 BEGIN @@ -386,7 +404,7 @@ BEGIN CONTROL "Disable TLS Version 1.&3",IDC_CHECK_NO_TLS_1_3,"Button",BS_AUTOCHECKBOX | WS_TABSTOP,90,164,87,10 END -IDD_TAB11 DIALOGEX 0, 0, 209, 181 +IDD_TAB12 DIALOGEX 0, 0, 209, 181 STYLE DS_SETFONT | DS_FIXEDSYS | WS_CHILD FONT 8, "MS Shell Dlg", 400, 0, 0x1 BEGIN diff --git a/setupgui/windows/resource.h b/setupgui/windows/resource.h index dcf7f33c3..d1dbcd73b 100644 --- a/setupgui/windows/resource.h +++ b/setupgui/windows/resource.h @@ -60,6 +60,7 @@ #define IDD_TAB9 140 #define IDD_TAB10 141 #define IDD_TAB11 142 +#define IDD_TAB12 143 #define IDC_LOGO 1000 #define IDC_EDIT 1010 #define IDC_EDIT_PASSWORD 1010 @@ -199,6 +200,12 @@ #define IDC_EDIT_FED_AUTH_HOST 11032 #define IDC_EDIT_FED_AUTH_PORT 11033 #define IDC_EDIT_FED_AUTH_EXPIRATION 11034 +#define IDC_CHECK_ENABLE_CUSTOM_ENDPOINT_MONITORING 11040 +#define IDC_CHECK_WAIT_FOR_CUSTOM_ENDPOINT_INFO 11041 +#define IDC_EDIT_CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS 11042 +#define IDC_EDIT_WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS 11043 +#define IDC_EDIT_CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS 11044 +#define IDC_EDIT_CUSTOM_ENDPOINT_REGION 11045 #define MYSQL_ADMIN_PORT 33062 #define IDC_STATIC -1 diff --git a/unit_testing/CMakeLists.txt b/unit_testing/CMakeLists.txt index 1d37f23fe..c8184dc78 100644 --- a/unit_testing/CMakeLists.txt +++ b/unit_testing/CMakeLists.txt @@ -57,6 +57,8 @@ add_executable( adfs_proxy_test.cc cluster_aware_metrics_test.cc + custom_endpoint_monitor_test.cc + custom_endpoint_proxy_test.cc efm_proxy_test.cc iam_proxy_test.cc failover_handler_test.cc diff --git a/unit_testing/custom_endpoint_monitor_test.cc b/unit_testing/custom_endpoint_monitor_test.cc new file mode 100644 index 000000000..e2bfd6f93 --- /dev/null +++ b/unit_testing/custom_endpoint_monitor_test.cc @@ -0,0 +1,107 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#include +#include + +#include + +#include +#include + +#include "driver/custom_endpoint_monitor.h" +#include "test_utils.h" +#include "mock_objects.h" + +using namespace Aws::RDS; + +using ::testing::_; +using ::testing::Return; +using ::testing::StrEq; + +namespace { +const std::string WRITER_CLUSTER_URL{"writer.cluster-XYZ.us-east-1.rds.amazonaws.com"}; +const std::string CUSTOM_ENDPOINT_URL{"custom.cluster-custom-XYZ.us-east-1.rds.amazonaws.com"}; +const auto ENDPOINT = Aws::Utils::Json::JsonValue(CUSTOM_ENDPOINT_URL); + +const long long REFRESH_RATE_NANOS = 50000000; +} // namespace + +static SQLHENV env; +static Aws::SDKOptions sdk_options; + +class CustomEndpointMonitorTest : public testing::Test { + protected: + DBC* dbc; + DataSource* ds; + MOCK_CONNECTION_PROXY* mock_connection_proxy; + std::shared_ptr mock_rds_client; + std::shared_ptr mock_topology_service; + ctpl::thread_pool monitor_thread_pool; + + static void SetUpTestSuite() { Aws::InitAPI(sdk_options); } + + static void TearDownTestSuite() { + Aws::ShutdownAPI(sdk_options); + mysql_library_end(); + } + + void SetUp() override { + allocate_odbc_handles(env, dbc, ds); + mock_rds_client = std::make_shared(); + mock_connection_proxy = new MOCK_CONNECTION_PROXY(dbc, ds); + } + + void TearDown() override { + cleanup_odbc_handles(env, dbc, ds); + TEST_UTILS::get_custom_endpoint_cache().clear(); + delete mock_connection_proxy; + } +}; + +TEST_F(CustomEndpointMonitorTest, TestRun) { + Model::DBClusterEndpoint endpoint; + endpoint.AddStaticMembers(CUSTOM_ENDPOINT_URL); + std::vector endpoints{endpoint}; + + const auto expected_result = Model::DescribeDBClusterEndpointsResult().WithDBClusterEndpoints(endpoints); + const auto expected_outcome = Model::DescribeDBClusterEndpointsOutcome(expected_result); + + EXPECT_CALL(*mock_rds_client, + DescribeDBClusterEndpoints( + Property("GetDBClusterEndpointIdentifier", + &Model::DescribeDBClusterEndpointsRequest::GetDBClusterEndpointIdentifier, StrEq("custom")))) + .WillRepeatedly(Return(expected_outcome)); + + CUSTOM_ENDPOINT_MONITOR monitor(mock_topology_service, CUSTOM_ENDPOINT_URL, "custom", "us-east-1", REFRESH_RATE_NANOS, + monitor_thread_pool, true, mock_rds_client); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + monitor.stop(); +} diff --git a/unit_testing/custom_endpoint_proxy_test.cc b/unit_testing/custom_endpoint_proxy_test.cc new file mode 100644 index 000000000..c02e5bd22 --- /dev/null +++ b/unit_testing/custom_endpoint_proxy_test.cc @@ -0,0 +1,106 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#include +#include +#include + +#include "test_utils.h" +#include "mock_objects.h" + +using ::testing::_; +using ::testing::Return; +using ::testing::StrEq; + +namespace { +const std::string WRITER_CLUSTER_URL{"writer.cluster-XYZ.us-east-1.rds.amazonaws.com"}; +const std::string CUSTOM_ENDPOINT_URL{"custom.cluster-custom-XYZ.us-east-1.rds.amazonaws.com"}; +} // namespace + +static SQLHENV env; +static Aws::SDKOptions options; + +class CustomEndpointProxyTest : public testing::Test { + protected: + DBC* dbc; + DataSource* ds; + MOCK_CONNECTION_PROXY* mock_connection_proxy; + ctpl::thread_pool monitor_thread_pool; + std::shared_ptr mock_monitor = + std::make_shared(monitor_thread_pool); + + static void SetUpTestSuite() {} + + static void TearDownTestSuite() { mysql_library_end(); } + + void SetUp() override { + allocate_odbc_handles(env, dbc, ds); + ds->opt_ENABLE_CUSTOM_ENDPOINT_MONITORING = true; + ds->opt_WAIT_FOR_CUSTOM_ENDPOINT_INFO = true; + ds->opt_CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS = 60000; + + mock_connection_proxy = new MOCK_CONNECTION_PROXY(dbc, ds); + } + + void TearDown() override { + TEST_UTILS::get_custom_endpoint_monitor_cache().release_resources(); + cleanup_odbc_handles(env, dbc, ds); + } +}; + +TEST_F(CustomEndpointProxyTest, TestConnect_MonitorNotCreatedIfNotCustomEndpointHost) { + TEST_CUSTOM_ENDPOINT_PROXY custom_endpoint_proxy(dbc, ds, mock_connection_proxy); + EXPECT_EQ(TEST_UTILS::get_custom_endpoint_monitor_cache().size(), 0); + custom_endpoint_proxy.connect(WRITER_CLUSTER_URL.c_str(), "", "", "", 3306, "", 0); + + EXPECT_EQ(TEST_UTILS::get_custom_endpoint_monitor_cache().size(), 0); + EXPECT_CALL(*mock_connection_proxy, connect(_, _, _, _, _, _, _)).Times(0); +} + +TEST_F(CustomEndpointProxyTest, TestConnect_MonitorCreated) { + TEST_CUSTOM_ENDPOINT_PROXY custom_endpoint_proxy(dbc, ds, mock_connection_proxy); + EXPECT_EQ(0, TEST_UTILS::get_custom_endpoint_monitor_cache().size()); + EXPECT_CALL(custom_endpoint_proxy, create_custom_endpoint_monitor(_)).WillOnce(Return(mock_monitor)); + EXPECT_CALL(*mock_connection_proxy, connect(_, _, _, _, _, _, _)).Times(1); + custom_endpoint_proxy.connect(CUSTOM_ENDPOINT_URL.c_str(), "", "", "", 3306, "", 0); + + EXPECT_EQ(TEST_UTILS::get_custom_endpoint_monitor_cache().size(), 1); +} + +TEST_F(CustomEndpointProxyTest, TestConnect_TimeoutWaitingForInfo) { + TEST_CUSTOM_ENDPOINT_PROXY custom_endpoint_proxy(dbc, ds, mock_connection_proxy); + ds->opt_WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS = 100; + EXPECT_EQ(TEST_UTILS::get_custom_endpoint_monitor_cache().size(), 0); + EXPECT_CALL(custom_endpoint_proxy, create_custom_endpoint_monitor(_)).WillOnce(Return(mock_monitor)); + ds->opt_WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS = 1; + custom_endpoint_proxy.connect(CUSTOM_ENDPOINT_URL.c_str(), "user", "pwd", "db", 3306, "", 0); + + EXPECT_EQ(TEST_UTILS::get_custom_endpoint_monitor_cache().size(), 1); + EXPECT_CALL(*mock_connection_proxy, connect(_, _, _, _, _, _, _)).Times(0); +} diff --git a/unit_testing/failover_handler_test.cc b/unit_testing/failover_handler_test.cc index f86f1ce2b..60c561bde 100644 --- a/unit_testing/failover_handler_test.cc +++ b/unit_testing/failover_handler_test.cc @@ -43,6 +43,7 @@ using ::testing::StrEq; namespace { const std::string US_EAST_REGION_CLUSTER = "database-test-name.cluster-XYZ.us-east-2.rds.amazonaws.com"; const std::string US_EAST_REGION_CLUSTER_READ_ONLY = "database-test-name.cluster-ro-XYZ.us-east-2.rds.amazonaws.com"; + const std::string US_EAST_REGION_INSTANCE = "instance-test-name.XYZ.us-east-2.rds.amazonaws.com"; const std::string US_EAST_REGION_PROXY = "proxy-test-name.proxy-XYZ.us-east-2.rds.amazonaws.com"; const std::string US_EAST_REGION_CUSTON_DOMAIN = "custom-test-name.cluster-custom-XYZ.us-east-2.rds.amazonaws.com"; @@ -424,6 +425,26 @@ TEST_F(FailoverHandlerTest, GetRdsClusterHostUrl) { EXPECT_EQ(std::string(), TEST_UTILS::get_rds_cluster_host_url(CHINA_REGION_CUSTON_DOMAIN)); } +TEST_F(FailoverHandlerTest, GetRdsClusterId) { + EXPECT_EQ("database-test-name", TEST_UTILS::get_rds_cluster_id(US_EAST_REGION_CLUSTER)); + EXPECT_EQ("database-test-name", TEST_UTILS::get_rds_cluster_id(US_EAST_REGION_CLUSTER_READ_ONLY)); + EXPECT_EQ(std::string(), TEST_UTILS::get_rds_cluster_id(US_EAST_REGION_INSTANCE)); + + EXPECT_EQ("proxy-test-name", TEST_UTILS::get_rds_cluster_id(US_EAST_REGION_PROXY)); + EXPECT_EQ("custom-test-name", TEST_UTILS::get_rds_cluster_id(US_EAST_REGION_CUSTON_DOMAIN)); + + EXPECT_EQ("database-test-name", TEST_UTILS::get_rds_cluster_id(CHINA_REGION_CLUSTER)); + EXPECT_EQ("database-test-name", TEST_UTILS::get_rds_cluster_id(CHINA_REGION_CLUSTER_READ_ONLY)); + EXPECT_EQ("proxy-test-name", TEST_UTILS::get_rds_cluster_id(CHINA_REGION_PROXY)); + EXPECT_EQ("custom-test-name", TEST_UTILS::get_rds_cluster_id(CHINA_REGION_CUSTON_DOMAIN)); +} + +TEST_F(FailoverHandlerTest, GetRdsInstanceId) { + EXPECT_EQ("instance-test-name", TEST_UTILS::get_rds_instance_id(US_EAST_REGION_INSTANCE)); + EXPECT_EQ(std::string(), TEST_UTILS::get_rds_instance_id(US_EAST_REGION_CLUSTER_READ_ONLY)); + EXPECT_EQ(std::string(), TEST_UTILS::get_rds_instance_id(US_EAST_REGION_CLUSTER)); +} + TEST_F(FailoverHandlerTest, ConnectToNewWriter) { std::string server = "my-cluster-name.cluster-XYZ.us-east-2.rds.amazonaws.com"; ds->opt_SERVER.set_remove_brackets((SQLWCHAR*)to_sqlwchar_string(server).c_str(), server.size()); diff --git a/unit_testing/mock_objects.h b/unit_testing/mock_objects.h index 80ebce772..94b18575b 100644 --- a/unit_testing/mock_objects.h +++ b/unit_testing/mock_objects.h @@ -35,6 +35,7 @@ #include #include "driver/connection_proxy.h" +#include "driver/custom_endpoint_proxy.h" #include "driver/failover.h" #include "driver/saml_http_client.h" #include "driver/monitor_thread_container.h" @@ -92,6 +93,7 @@ class MOCK_TOPOLOGY_SERVICE : public TOPOLOGY_SERVICE { MOCK_METHOD(void, set_cluster_id, (std::string)); MOCK_METHOD(void, set_cluster_instance_template, (std::shared_ptr)); MOCK_METHOD(std::shared_ptr, get_topology, (CONNECTION_PROXY*, bool)); + MOCK_METHOD(std::shared_ptr, get_filtered_topology, (CONNECTION_PROXY*, bool)); MOCK_METHOD(void, mark_host_down, (std::shared_ptr)); MOCK_METHOD(void, mark_host_up, (std::shared_ptr)); }; @@ -222,6 +224,13 @@ class MOCK_SECRETS_MANAGER_CLIENT : public Aws::SecretsManager::SecretsManagerCl MOCK_METHOD(Aws::SecretsManager::Model::GetSecretValueOutcome, GetSecretValue, (const Aws::SecretsManager::Model::GetSecretValueRequest&), (const)); }; +class MOCK_RDS_CLIENT : public Aws::RDS::RDSClient { +public: + MOCK_RDS_CLIENT() : RDSClient(){}; + + MOCK_METHOD(Aws::RDS::Model::DescribeDBClusterEndpointsOutcome, DescribeDBClusterEndpoints, (const Aws::RDS::Model::DescribeDBClusterEndpointsRequest&), (const)); +}; + class MOCK_AUTH_UTIL : public AUTH_UTIL { public: MOCK_AUTH_UTIL() : AUTH_UTIL() {}; @@ -234,4 +243,16 @@ class MOCK_SAML_HTTP_CLIENT : public SAML_HTTP_CLIENT { MOCK_METHOD(nlohmann::json, post, (const std::string&, const std::string&, const std::string&)); MOCK_METHOD(nlohmann::json, get, (const std::string&, const httplib::Headers&)); }; + +class TEST_CUSTOM_ENDPOINT_PROXY : public CUSTOM_ENDPOINT_PROXY { +public: + TEST_CUSTOM_ENDPOINT_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy) : CUSTOM_ENDPOINT_PROXY(dbc, ds, next_proxy) {}; + MOCK_METHOD(std::shared_ptr, create_custom_endpoint_monitor, (const long long refresh_rate_nanos), (override)); + static int get_monitor_size() { return monitors.size(); } +}; + +class MOCK_CUSTOM_ENDPOINT_MONITOR : public CUSTOM_ENDPOINT_MONITOR { + public: + MOCK_CUSTOM_ENDPOINT_MONITOR(ctpl::thread_pool& pool) : CUSTOM_ENDPOINT_MONITOR(pool) {}; +}; #endif /* __MOCKOBJECTS_H__ */ diff --git a/unit_testing/sliding_expiration_cache_test.cc b/unit_testing/sliding_expiration_cache_test.cc index b623c4b8a..e9364a613 100644 --- a/unit_testing/sliding_expiration_cache_test.cc +++ b/unit_testing/sliding_expiration_cache_test.cc @@ -145,7 +145,7 @@ TEST_F(SlidingExpirationCacheTest, ExpirationTimeUpdateGet) { TEST_F(SlidingExpirationCacheTest, GetCacheExpireThread) { SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD cache(cache_exp_short); - + cache.init_clean_up_thread(); EXPECT_EQ(0, cache.size()); cache.put(cache_key_a, cache_val_a, cache_exp_short); cache.put(cache_key_b, cache_val_b, cache_exp_long); diff --git a/unit_testing/test_utils.cc b/unit_testing/test_utils.cc index 3b96eb9bf..7a01cb977 100644 --- a/unit_testing/test_utils.cc +++ b/unit_testing/test_utils.cc @@ -29,6 +29,9 @@ #include "test_utils.h" +#include "driver/custom_endpoint_monitor.h" +#include "driver/custom_endpoint_proxy.h" + void allocate_odbc_handles(SQLHENV& env, DBC*& dbc, DataSource*& ds) { SQLHDBC hdbc = nullptr; @@ -155,6 +158,23 @@ std::string TEST_UTILS::get_rds_cluster_host_url(std::string host) { return RDS_UTILS::get_rds_cluster_host_url(host); } +std::string TEST_UTILS::get_rds_cluster_id(std::string host) { + return RDS_UTILS::get_rds_cluster_id(host); +} + +std::string TEST_UTILS::get_rds_instance_id(std::string host) { + return RDS_UTILS::get_rds_instance_id(host); +} + std::string TEST_UTILS::get_rds_instance_host_pattern(std::string host) { return RDS_UTILS::get_rds_instance_host_pattern(host); } + +CACHE_MAP>& TEST_UTILS::get_custom_endpoint_cache() { + return std::ref(CUSTOM_ENDPOINT_MONITOR::custom_endpoint_cache); +} + +SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD>& +TEST_UTILS::get_custom_endpoint_monitor_cache() { + return std::ref(CUSTOM_ENDPOINT_PROXY::monitors); +} diff --git a/unit_testing/test_utils.h b/unit_testing/test_utils.h index d7ce4f84f..f431ce8fc 100644 --- a/unit_testing/test_utils.h +++ b/unit_testing/test_utils.h @@ -30,46 +30,54 @@ #ifndef __TESTUTILS_H__ #define __TESTUTILS_H__ -#include "driver/auth_util.h" +#include "driver/auth_util.h" +#include "driver/cache_map.h" +#include "driver/custom_endpoint_info.h" +#include "driver/custom_endpoint_monitor.h" #include "driver/driver.h" #include "driver/failover.h" #include "driver/iam_proxy.h" -#include "driver/okta_proxy.h" #include "driver/monitor.h" #include "driver/monitor_thread_container.h" -#include "driver/secrets_manager_proxy.h" +#include "driver/okta_proxy.h" #include "driver/rds_utils.h" +#include "driver/secrets_manager_proxy.h" +#include "driver/sliding_expiration_cache_with_clean_up_thread.h" void allocate_odbc_handles(SQLHENV& env, DBC*& dbc, DataSource*& ds); void cleanup_odbc_handles(SQLHENV env, DBC*& dbc, DataSource*& ds, bool call_myodbc_end = false); class TEST_UTILS { -public: - static std::chrono::milliseconds get_connection_check_interval(std::shared_ptr monitor); - static CONNECTION_STATUS check_connection_status(std::shared_ptr monitor); - static void populate_monitor_map(std::shared_ptr container, - std::set node_keys, std::shared_ptr monitor); - static void populate_task_map(std::shared_ptr container, - std::shared_ptr monitor); - static bool has_monitor(std::shared_ptr container, std::string node_key); - static bool has_task(std::shared_ptr container, std::shared_ptr monitor); - static bool has_any_tasks(std::shared_ptr container); - static bool has_available_monitor(std::shared_ptr container); - static std::shared_ptr get_available_monitor(std::shared_ptr container); - static size_t get_map_size(std::shared_ptr container); - static std::list> get_contexts(std::shared_ptr monitor); - static std::string build_cache_key(const char* host, const char* region, unsigned int port, const char* user); - static bool token_cache_contains_key(std::unordered_map token_cache, std::string cache_key); - static std::map, Aws::Utils::Json::JsonValue>& get_secrets_cache(); - static bool try_parse_region_from_secret(std::string secret, std::string& region); - static bool is_dns_pattern_valid(std::string host); - static bool is_rds_dns(std::string host); - static bool is_rds_cluster_dns(std::string host); - static bool is_rds_proxy_dns(std::string host); - static bool is_rds_writer_cluster_dns(std::string host); - static bool is_rds_custom_cluster_dns(std::string host); - static std::string get_rds_cluster_host_url(std::string host); - static std::string get_rds_instance_host_pattern(std::string host); + public: + static std::chrono::milliseconds get_connection_check_interval(std::shared_ptr monitor); + static CONNECTION_STATUS check_connection_status(std::shared_ptr monitor); + static void populate_monitor_map(std::shared_ptr container, std::set node_keys, + std::shared_ptr monitor); + static void populate_task_map(std::shared_ptr container, std::shared_ptr monitor); + static bool has_monitor(std::shared_ptr container, std::string node_key); + static bool has_task(std::shared_ptr container, std::shared_ptr monitor); + static bool has_any_tasks(std::shared_ptr container); + static bool has_available_monitor(std::shared_ptr container); + static std::shared_ptr get_available_monitor(std::shared_ptr container); + static size_t get_map_size(std::shared_ptr container); + static std::list> get_contexts(std::shared_ptr monitor); + static std::string build_cache_key(const char* host, const char* region, unsigned int port, const char* user); + static bool token_cache_contains_key(std::unordered_map token_cache, std::string cache_key); + static std::map, Aws::Utils::Json::JsonValue>& get_secrets_cache(); + static bool try_parse_region_from_secret(std::string secret, std::string& region); + static bool is_dns_pattern_valid(std::string host); + static bool is_rds_dns(std::string host); + static bool is_rds_cluster_dns(std::string host); + static bool is_rds_proxy_dns(std::string host); + static bool is_rds_writer_cluster_dns(std::string host); + static bool is_rds_custom_cluster_dns(std::string host); + static std::string get_rds_cluster_host_url(std::string host); + static std::string get_rds_cluster_id(std::string host); + static std::string get_rds_instance_id(std::string host); + static std::string get_rds_instance_host_pattern(std::string host); + static CACHE_MAP>& get_custom_endpoint_cache(); + static SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD>& + get_custom_endpoint_monitor_cache(); }; #endif /* __TESTUTILS_H__ */ diff --git a/util/installer.cc b/util/installer.cc index e5978810a..004dec17a 100644 --- a/util/installer.cc +++ b/util/installer.cc @@ -284,6 +284,14 @@ static SQLWCHAR W_FAILURE_DETECTION_COUNT[] = { 'F', 'A', 'I', 'L', 'U', 'R', 'E static SQLWCHAR W_MONITOR_DISPOSAL_TIME[] = { 'M', 'O', 'N', 'I', 'T', 'O', 'R', '_', 'D', 'I', 'S', 'P', 'O', 'S', 'A', 'L', '_', 'T', 'I', 'M', 'E', 0 }; static SQLWCHAR W_FAILURE_DETECTION_TIMEOUT[] = { 'F', 'A', 'I', 'L', 'U', 'R', 'E', '_', 'D', 'E', 'T', 'E', 'C', 'T', 'I', 'O', 'N', '_', 'T', 'I', 'M', 'E', 'O', 'U', 'T', 0 }; +/* Custom Endpoint */ +static SQLWCHAR W_ENABLE_CUSTOM_ENDPOINT_MONITORING[] = {'E', 'N', 'A', 'B', 'L', 'E', '_', 'C', 'U', 'S', 'T', 'O', 'M', '_', 'E', 'N', 'D', 'P', 'O', 'I', 'N', 'T', '_', 'M', 'O', 'N', 'I', 'T', 'O', 'R', 'I', 'N', 'G', 0}; +static SQLWCHAR W_CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS[] = { 'C', 'U', 'S', 'T', 'O', 'M', '_', 'E', 'N', 'D', 'P', 'O', 'I', 'N', 'T', '_', 'I', 'N', 'F', 'O', '_', 'R', 'E', 'F', 'R', 'E', 'S', 'H', '_', 'R', 'A', 'T', 'E', '_', 'M', 'S', 0 }; +static SQLWCHAR W_WAIT_FOR_CUSTOM_ENDPOINT_INFO[] = { 'W', 'A', 'I', 'T', '_', 'F', 'O', 'R', '_', 'C', 'U', 'S', 'T', 'O', 'M', '_', 'E', 'N', 'D', 'P', 'O', 'I', 'N', 'T', '_', 'I', 'N', 'F', 'O', 0 }; +static SQLWCHAR W_WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS[] = { 'W', 'A', 'I', 'T', '_', 'F', 'O', 'R', '_', 'C', 'U', 'S', 'T', 'O', 'M', '_', 'E', 'N', 'D', 'P', 'O', 'I', 'N', 'T', '_', 'I', 'N', 'F', 'O', '_', 'T', 'I', 'M', 'E', 'O', 'U', 'T', '_', 'M', 'S', 0 }; +static SQLWCHAR W_CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS[] = { 'C', 'U', 'S', 'T', 'O', 'M', '_', 'E', 'N', 'D', 'P', 'O', 'I', 'N', 'T', '_', 'M', 'O', 'N', 'I', 'T', 'O', 'R', '_', 'E', 'X', 'P', 'I', 'R', 'A', 'T', 'I', 'O', 'N', '_', 'M', 'S', 0 }; +static SQLWCHAR W_CUSTOM_ENDPOINT_REGION[] = { 'C', 'U', 'S', 'T', 'O', 'M', '_', 'E', 'N', 'D', 'P', 'O', 'I', 'N', 'T', '_', 'R', 'E', 'G', 'I', 'O', 'N', 0 }; + /* DS_PARAM */ /* externally used strings */ const SQLWCHAR W_DRIVER_PARAM[]= {';', 'D', 'R', 'I', 'V', 'E', 'R', '=', 0}; @@ -341,7 +349,12 @@ SQLWCHAR *dsnparams[]= {W_DSN, W_DRIVER, W_DESCRIPTION, W_SERVER, /* Monitoring */ W_ENABLE_FAILURE_DETECTION, W_FAILURE_DETECTION_TIME, W_FAILURE_DETECTION_INTERVAL, W_FAILURE_DETECTION_COUNT, - W_MONITOR_DISPOSAL_TIME, W_FAILURE_DETECTION_TIMEOUT}; + W_MONITOR_DISPOSAL_TIME, W_FAILURE_DETECTION_TIMEOUT, + /* Custom Endpoints */ + W_ENABLE_CUSTOM_ENDPOINT_MONITORING, + W_CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS, W_WAIT_FOR_CUSTOM_ENDPOINT_INFO, + W_WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS, W_CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS, W_CUSTOM_ENDPOINT_REGION}; + static const int dsnparamcnt= sizeof(dsnparams) / sizeof(SQLWCHAR *); /* DS_PARAM */ @@ -675,7 +688,7 @@ int Driver::from_kvpair_semicolon(const SQLWCHAR *attrs) memcpy(attribute, attrs, (split - attrs) * sizeof(SQLWCHAR)); attribute[split - attrs]= 0; /* add null term */ ++split; - + /* if its one we want, copy it over */ if (!sqlwcharcasecmp(W_DRIVER, attribute)) dest = &lib; @@ -1060,6 +1073,11 @@ void DataSource::reset() { this->opt_MONITOR_DISPOSAL_TIME.set_default(MONITOR_DISPOSAL_TIME_MS); this->opt_FAILURE_DETECTION_TIMEOUT.set_default(FAILURE_DETECTION_TIMEOUT_SECS); + this->opt_WAIT_FOR_CUSTOM_ENDPOINT_INFO.set_default(true); + this->opt_WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS.set_default(5000); + this->opt_CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS.set_default(30000); + this->opt_CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS.set_default(900000); + this->opt_AUTH_PORT.set_default(opt_PORT); this->opt_AUTH_EXPIRATION.set_default(900); // 15 minutes this->opt_FED_AUTH_PORT.set_default(opt_PORT); diff --git a/util/installer.h b/util/installer.h index 7968cfc4b..1e17ec51d 100644 --- a/util/installer.h +++ b/util/installer.h @@ -364,49 +364,53 @@ unsigned int get_network_timeout(unsigned int seconds); X(FAILURE_DETECTION_TIMEOUT) \ X(MONITOR_DISPOSAL_TIME) +#define CUSTOM_ENDPOINT_BOOL_OPTIONS_LIST(X) X(WAIT_FOR_CUSTOM_ENDPOINT_INFO) \ + X(ENABLE_CUSTOM_ENDPOINT_MONITORING) + +#define CUSTOM_ENDPOINT_INT_OPTIONS_LIST(X) \ + X(CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS) \ + X(WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS) \ + X(CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS) + +#define CUSTOM_ENDPOINT_STR_OPTIONS_LIST(X) X(CUSTOM_ENDPOINT_REGION) + #define STR_OPTIONS_LIST(X) \ X(DSN) \ X(DRIVER) \ X(DESCRIPTION) \ - X(SERVER) \ - X(UID) \ - X(PWD) MFA_OPTS(X) X(DATABASE) X(SOCKET) X(INITSTMT) X(CHARSET) X(SSL_KEY) \ - X(SSL_CERT) X(SSL_CA) X(SSL_CAPATH) X(SSL_CIPHER) X(SSL_MODE) X(RSAKEY) \ - X(SAVEFILE) X(PLUGIN_DIR) X(DEFAULT_AUTH) X(LOAD_DATA_LOCAL_DIR) \ - X(OCI_CONFIG_FILE) X(OCI_CONFIG_PROFILE) \ - X(AUTHENTICATION_KERBEROS_MODE) X(TLS_VERSIONS) X(SSL_CRL) \ - X(SSL_CRLPATH) X(SSLVERIFY) X(OPENTELEMETRY) \ - AWS_AUTH_STR_OPTIONS_LIST(X) FAILOVER_STR_OPTIONS_LIST(X) FED_AUTH_STR_OPTIONS_LIST(X) - -#define INT_OPTIONS_LIST(X) \ - X(PORT) \ - X(READTIMEOUT) \ - X(WRITETIMEOUT) \ - X(CLIENT_INTERACTIVE) X(PREFETCH) FAILOVER_INT_OPTIONS_LIST(X) \ - AWS_AUTH_INT_OPTIONS_LIST(X) MONITORING_INT_OPTIONS_LIST(X) FED_AUTH_INT_OPTIONS_LIST(X) + X(SERVER) \ + X(UID) \ + X(PWD) \ + MFA_OPTS(X) X(DATABASE) X(SOCKET) X(INITSTMT) X(CHARSET) X(SSL_KEY) X(SSL_CERT) X(SSL_CA) X(SSL_CAPATH) \ + X(SSL_CIPHER) X(SSL_MODE) X(RSAKEY) X(SAVEFILE) X(PLUGIN_DIR) X(DEFAULT_AUTH) X(LOAD_DATA_LOCAL_DIR) \ + X(OCI_CONFIG_FILE) X(OCI_CONFIG_PROFILE) X(AUTHENTICATION_KERBEROS_MODE) X(TLS_VERSIONS) X(SSL_CRL) \ + X(SSL_CRLPATH) X(SSLVERIFY) X(OPENTELEMETRY) AWS_AUTH_STR_OPTIONS_LIST(X) FAILOVER_STR_OPTIONS_LIST(X) \ + CUSTOM_ENDPOINT_STR_OPTIONS_LIST(X) FED_AUTH_STR_OPTIONS_LIST(X) + +#define INT_OPTIONS_LIST(X) \ + X(PORT) \ + X(READTIMEOUT) \ + X(WRITETIMEOUT) \ + X(CLIENT_INTERACTIVE) \ + X(PREFETCH) FAILOVER_INT_OPTIONS_LIST(X) AWS_AUTH_INT_OPTIONS_LIST(X) MONITORING_INT_OPTIONS_LIST(X) \ + CUSTOM_ENDPOINT_INT_OPTIONS_LIST(X) FED_AUTH_INT_OPTIONS_LIST(X) // TODO: remove AUTO_RECONNECT when special handling (warning) // is not needed anymore. -#define BOOL_OPTIONS_LIST(X) \ - X(FOUND_ROWS) \ - X(BIG_PACKETS) \ - X(COMPRESSED_PROTO) \ - X(NO_BIGINT) \ - X(SAFE) \ - X(AUTO_RECONNECT) X(AUTO_IS_NULL) X(NO_BINARY_RESULT) X(CAN_HANDLE_EXP_PWD) \ - X(ENABLE_CLEARTEXT_PLUGIN) X(GET_SERVER_PUBLIC_KEY) X(NO_PROMPT) \ - X(DYNAMIC_CURSOR) X(NO_DEFAULT_CURSOR) X(NO_LOCALE) X(PAD_SPACE) \ - X(NO_CACHE) X(FULL_COLUMN_NAMES) X(IGNORE_SPACE) X(NAMED_PIPE) \ - X(NO_CATALOG) X(NO_SCHEMA) X(USE_MYCNF) X(NO_TRANSACTIONS) \ - X(FORWARD_CURSOR) X(MULTI_STATEMENTS) X(COLUMN_SIZE_S32) \ - X(MIN_DATE_TO_ZERO) X(ZERO_DATE_TO_MIN) X( \ - DFLT_BIGINT_BIND_STR) X(LOG_QUERY) X(NO_SSPS) \ - X(NO_TLS_1_2) X(NO_TLS_1_3) X(NO_DATE_OVERFLOW) \ - X(ENABLE_LOCAL_INFILE) X(ENABLE_DNS_SRV) \ - X(MULTI_HOST) \ - FAILOVER_BOOL_OPTIONS_LIST(X) \ - MONITORING_BOOL_OPTIONS_LIST(X) \ - FED_AUTH_BOOL_OPTIONS_LIST(X) +#define BOOL_OPTIONS_LIST(X) \ + X(FOUND_ROWS) \ + X(BIG_PACKETS) \ + X(COMPRESSED_PROTO) \ + X(NO_BIGINT) \ + X(SAFE) \ + X(AUTO_RECONNECT) \ + X(AUTO_IS_NULL) X(NO_BINARY_RESULT) X(CAN_HANDLE_EXP_PWD) X(ENABLE_CLEARTEXT_PLUGIN) X(GET_SERVER_PUBLIC_KEY) \ + X(NO_PROMPT) X(DYNAMIC_CURSOR) X(NO_DEFAULT_CURSOR) X(NO_LOCALE) X(PAD_SPACE) X(NO_CACHE) X(FULL_COLUMN_NAMES) \ + X(IGNORE_SPACE) X(NAMED_PIPE) X(NO_CATALOG) X(NO_SCHEMA) X(USE_MYCNF) X(NO_TRANSACTIONS) X(FORWARD_CURSOR) \ + X(MULTI_STATEMENTS) X(COLUMN_SIZE_S32) X(MIN_DATE_TO_ZERO) X(ZERO_DATE_TO_MIN) X(DFLT_BIGINT_BIND_STR) \ + X(LOG_QUERY) X(NO_SSPS) X(NO_TLS_1_2) X(NO_TLS_1_3) X(NO_DATE_OVERFLOW) X(ENABLE_LOCAL_INFILE) \ + X(ENABLE_DNS_SRV) X(MULTI_HOST) FAILOVER_BOOL_OPTIONS_LIST(X) MONITORING_BOOL_OPTIONS_LIST(X) \ + CUSTOM_ENDPOINT_BOOL_OPTIONS_LIST(X) FED_AUTH_BOOL_OPTIONS_LIST(X) #define FULL_OPTIONS_LIST(X) \ STR_OPTIONS_LIST(X) INT_OPTIONS_LIST(X) BOOL_OPTIONS_LIST(X)