From 4bbc8c94c265ff429642ab9f7871fe2806366bf1 Mon Sep 17 00:00:00 2001 From: Karen Chen Date: Mon, 13 Jan 2025 23:26:20 -0800 Subject: [PATCH] test: custom endpoints unit tests and integration tests --- .github/workflows/failover.yml | 1 + docs/using-the-aws-driver/CustomEndpoint.md | 24 + driver/CMakeLists.txt | 4 +- driver/cache_map.h | 4 +- driver/cluster_topology_info.cc | 22 + driver/cluster_topology_info.h | 2 + driver/custom_endpoint_info.h | 5 +- driver/custom_endpoint_monitor.cc | 80 +- driver/custom_endpoint_monitor.h | 24 +- driver/custom_endpoint_proxy.cc | 73 +- driver/custom_endpoint_proxy.h | 20 +- driver/driver.h | 6 + driver/failover_handler.cc | 3 +- driver/handle.cc | 11 +- driver/host_info.cc | 21 +- driver/host_info.h | 2 + driver/rds_utils.cc | 39 +- driver/rds_utils.h | 1 + driver/sliding_expiration_cache.cc | 6 +- ...g_expiration_cache_with_clean_up_thread.cc | 44 +- ...ng_expiration_cache_with_clean_up_thread.h | 14 +- driver/topology_service.cc | 44 +- driver/topology_service.h | 141 +-- integration/CMakeLists.txt | 1 + integration/base_failover_integration_test.cc | 18 +- integration/connection_string_builder.h | 840 ++++++++++-------- .../custom_endpoint_integration_test.cc | 178 ++++ unit_testing/CMakeLists.txt | 2 + unit_testing/custom_endpoint_monitor_test.cc | 109 +++ unit_testing/custom_endpoint_proxy_test.cc | 104 +++ unit_testing/failover_handler_test.cc | 21 + unit_testing/mock_objects.h | 20 + unit_testing/sliding_expiration_cache_test.cc | 2 +- unit_testing/test_utils.cc | 20 + unit_testing/test_utils.h | 66 +- 35 files changed, 1336 insertions(+), 636 deletions(-) create mode 100644 docs/using-the-aws-driver/CustomEndpoint.md create mode 100644 integration/custom_endpoint_integration_test.cc create mode 100644 unit_testing/custom_endpoint_monitor_test.cc create mode 100644 unit_testing/custom_endpoint_proxy_test.cc diff --git a/.github/workflows/failover.yml b/.github/workflows/failover.yml index de8c1e3df..05a7cdd13 100644 --- a/.github/workflows/failover.yml +++ b/.github/workflows/failover.yml @@ -1,6 +1,7 @@ name: Failover Unit Tests on: + workflow_dispatch: push: branches: - main diff --git a/docs/using-the-aws-driver/CustomEndpoint.md b/docs/using-the-aws-driver/CustomEndpoint.md new file mode 100644 index 000000000..03dcd7d54 --- /dev/null +++ b/docs/using-the-aws-driver/CustomEndpoint.md @@ -0,0 +1,24 @@ +# Custom Endpoint Support + +The Custom Endpoint support allows client application to use the driver with RDS custom endpoints. When the Custom Endpoint feature is enabled, the driver will analyse custom endpoint information to ensure instances used in connections are part of the custom endpoint being used. This includes connections used in failover. + +## How to use the Driver with Custom Endpoint + +### Enabling the Custom Endpoint Feature + +1. If needed, create a custom endpoint using the AWS RDS Console: + - If needed, review the documentation about [creating a custom endpoint](https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/aurora-custom-endpoint-creating.html). +2. Set `ENABLE_CUSTOM_ENDPOINT_MONITORING` to `TRUE` to enable custom endpoint support. +3. If you are using the failover plugin, set the failover parameter `FAILOVER_MODE` according to the custom endpoint type. For example, if the custom endpoint you are using is of type `READER`, you can set `FAILOVER_MODE` to `strict-reader`, or if it is of type `ANY`, you can set `FAILOVER_MODE` to `reader-or-writer`. +4. Specify parameters that are required or specific to your case. + +### Custom Endpoint Plugin Parameters + +| Parameter | Value | Required | Description | Default Value | Example Value | +| ------------------------------------------ | :----: | :------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------- | ------------- | +| `ENABLE_CUSTOM_ENDPOINT_MONITORING` | bool | No | Set to TRUE to enable custom endpoint support. | `FALSE` | `TRUE` | +| `CUSTOM_ENDPOINT_REGION` | string | No | The region of the cluster's custom endpoints. If not specified, the region will be parsed from the URL. | `N/A` | `us-west-1` | +| `CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS` | long | No | Controls how frequently custom endpoint monitors fetch custom endpoint info, in milliseconds. | `30000` | `20000` | +| `CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS` | long | No | Controls how long a monitor should run without use before expiring and being removed, in milliseconds. | `900000` (15 minutes) | `600000` | +| `WAIT_FOR_CUSTOM_ENDPOINT_INFO` | bool | No | Controls whether to wait for custom endpoint info to become available before connecting or executing a method. Waiting is only necessary if a connection to a given custom endpoint has not been opened or used recently. Note that disabling this may result in occasional connections to instances outside of the custom endpoint. | `true` | `true` | +| `WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS` | long | No | Controls the maximum amount of time that the plugin will wait for custom endpoint info to be made available by the custom endpoint monitor, in milliseconds. | `5000` | `7000` | diff --git a/driver/CMakeLists.txt b/driver/CMakeLists.txt index cb172efbe..e0a5252cc 100644 --- a/driver/CMakeLists.txt +++ b/driver/CMakeLists.txt @@ -73,9 +73,9 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT}) connect.cc connection_handler.cc connection_proxy.cc - custom_endpoint_proxy.cc custom_endpoint_info.cc custom_endpoint_monitor.cc + custom_endpoint_proxy.cc cursor.cc desc.cc dll.cc @@ -148,9 +148,9 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT}) cluster_topology_info.h connection_handler.h connection_proxy.h - custom_endpoint_proxy.h custom_endpoint_info.h custom_endpoint_monitor.h + custom_endpoint_proxy.h driver.h efm_proxy.h error.h diff --git a/driver/cache_map.h b/driver/cache_map.h index 7e1e11753..f01adaecf 100644 --- a/driver/cache_map.h +++ b/driver/cache_map.h @@ -27,8 +27,8 @@ // along with this program. If not, see // http://www.gnu.org/licenses/gpl-2.0.html. -#ifndef __CACHE_MAP__ -#define __CACHE_MAP__ +#ifndef __CACHE_MAP_H__ +#define __CACHE_MAP_H__ #include #include diff --git a/driver/cluster_topology_info.cc b/driver/cluster_topology_info.cc index d33360d8a..db049caf7 100644 --- a/driver/cluster_topology_info.cc +++ b/driver/cluster_topology_info.cc @@ -30,6 +30,7 @@ #include "cluster_topology_info.h" #include +#include /** Initialize and return random number. @@ -75,6 +76,20 @@ void CLUSTER_TOPOLOGY_INFO::add_host(std::shared_ptr host_info) { update_time(); } +void CLUSTER_TOPOLOGY_INFO::remove_host(std::shared_ptr host_info) { + auto position = std::find(writers.begin(), writers.end(), host_info); + if (position != writers.end()) { + writers.erase(position); + } + + position = std::find(readers.begin(), readers.end(), host_info); + if (position != readers.end()) { + readers.erase(position); + } + update_time(); +} + + size_t CLUSTER_TOPOLOGY_INFO::total_hosts() { return writers.size() + readers.size(); } @@ -136,6 +151,13 @@ std::vector> CLUSTER_TOPOLOGY_INFO::get_writers() { return writers; } +std::vector> CLUSTER_TOPOLOGY_INFO::get_instances() { + std::vector instances(writers); + instances.insert(instances.end(), writers.begin(), writers.end()); + + return instances; +} + std::shared_ptr CLUSTER_TOPOLOGY_INFO::get_last_used_reader() { return last_used_reader; } diff --git a/driver/cluster_topology_info.h b/driver/cluster_topology_info.h index 90d840370..1e7271ee7 100644 --- a/driver/cluster_topology_info.h +++ b/driver/cluster_topology_info.h @@ -46,6 +46,7 @@ class CLUSTER_TOPOLOGY_INFO { virtual ~CLUSTER_TOPOLOGY_INFO(); void add_host(std::shared_ptr host_info); + void remove_host(std::shared_ptr host_info); size_t total_hosts(); size_t num_readers(); // return number of readers in the cluster std::time_t time_last_updated(); @@ -58,6 +59,7 @@ class CLUSTER_TOPOLOGY_INFO { std::shared_ptr get_reader(int i); std::vector> get_writers(); std::vector> get_readers(); + std::vector> get_instances(); private: int current_reader = -1; diff --git a/driver/custom_endpoint_info.h b/driver/custom_endpoint_info.h index fe67706f1..738be6b62 100644 --- a/driver/custom_endpoint_info.h +++ b/driver/custom_endpoint_info.h @@ -33,10 +33,7 @@ #include #include -#include -#include - -#include "MYODBC_MYSQL.h" +#include "stringutil.h" #include "mylog.h" /** diff --git a/driver/custom_endpoint_monitor.cc b/driver/custom_endpoint_monitor.cc index 8a9e0844a..9fdabec59 100644 --- a/driver/custom_endpoint_monitor.cc +++ b/driver/custom_endpoint_monitor.cc @@ -27,18 +27,17 @@ // along with this program. If not, see // http://www.gnu.org/licenses/gpl-2.0.html. -#include "custom_endpoint_monitor.h" - #include #include #include #include +#include #include #include "allowed_and_blocked_hosts.h" #include "aws_sdk_helper.h" +#include "custom_endpoint_monitor.h" #include "driver.h" -#include "monitor_service.h" #include "mylog.h" namespace { @@ -47,13 +46,16 @@ AWS_SDK_HELPER SDK_HELPER; CACHE_MAP> CUSTOM_ENDPOINT_MONITOR::custom_endpoint_cache; -CUSTOM_ENDPOINT_MONITOR::CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr& custom_endpoint_host_info, +CUSTOM_ENDPOINT_MONITOR::CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr topology_service, + const std::string& custom_endpoint_host, const std::string& endpoint_identifier, const std::string& region, - DataSource* ds, bool enable_logging) - : custom_endpoint_host_info(custom_endpoint_host_info), - endpoint_identifier(endpoint_identifier), - region(region), - enable_logging(enable_logging) { + long long refresh_rate_nanos, bool enable_logging) + : topology_service(topology_service), + custom_endpoint_host(custom_endpoint_host), + endpoint_identifier(endpoint_identifier), + region(region), + refresh_rate_nanos(refresh_rate_nanos), + enable_logging(enable_logging) { if (enable_logging) { this->logger = init_log_file(); } @@ -66,23 +68,42 @@ CUSTOM_ENDPOINT_MONITOR::CUSTOM_ENDPOINT_MONITOR(const std::shared_ptrrds_client = std::make_shared( - Aws::Auth::DefaultAWSCredentialsProviderChain().GetAWSCredentials(), client_config); + Aws::Auth::DefaultAWSCredentialsProviderChain().GetAWSCredentials(), client_config); this->run(); } +#ifdef UNIT_TEST_BUILD +CUSTOM_ENDPOINT_MONITOR::CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr topology_service, + const std::string& custom_endpoint_host, + const std::string& endpoint_identifier, const std::string& region, + long long refresh_rate_nanos, bool enable_logging, + std::shared_ptr client) + : topology_service(topology_service), + custom_endpoint_host(custom_endpoint_host), + endpoint_identifier(endpoint_identifier), + region(region), + refresh_rate_nanos(refresh_rate_nanos), + enable_logging(enable_logging), + rds_client(std::move(client)) { + if (enable_logging) { + this->logger = init_log_file(); + } + this->run(); +} +#endif + bool CUSTOM_ENDPOINT_MONITOR::should_dispose() { return true; } bool CUSTOM_ENDPOINT_MONITOR::has_custom_endpoint_info() const { auto default_val = std::shared_ptr(nullptr); - return custom_endpoint_cache.get(this->custom_endpoint_host_info->get_host(), default_val) != default_val; + return custom_endpoint_cache.get(this->custom_endpoint_host, default_val) != default_val; } void CUSTOM_ENDPOINT_MONITOR::run() { this->thread_pool.resize(1); this->thread_pool.push([=](int id) { - MYLOG_TRACE(this->logger, 0, "Starting custom endpoint monitor for '%s'", - this->custom_endpoint_host_info->get_host().c_str()); + MYLOG_TRACE(this->logger, 0, "Starting custom endpoint monitor for '%s'", this->custom_endpoint_host.c_str()); try { while (!this->should_stop.load()) { @@ -92,7 +113,7 @@ void CUSTOM_ENDPOINT_MONITOR::run() { filter.SetValues({"custom"}); Aws::RDS::Model::DescribeDBClusterEndpointsRequest request; - request.SetDBClusterIdentifier(this->endpoint_identifier); + request.SetDBClusterEndpointIdentifier(this->endpoint_identifier); request.SetFilters({filter}); const auto response = this->rds_client->DescribeDBClusterEndpoints(request); @@ -108,37 +129,37 @@ void CUSTOM_ENDPOINT_MONITOR::run() { continue; } const std::shared_ptr endpoint_info = - CUSTOM_ENDPOINT_INFO::from_db_cluster_endpoint(custom_endpoints[0]); + CUSTOM_ENDPOINT_INFO::from_db_cluster_endpoint(custom_endpoints[0]); const std::shared_ptr cache_endpoint_info = - custom_endpoint_cache.get(this->custom_endpoint_host_info->get_host(), nullptr); + custom_endpoint_cache.get(this->custom_endpoint_host, nullptr); if (cache_endpoint_info != nullptr && cache_endpoint_info == endpoint_info) { const long long elapsed_time = - std::chrono::duration_cast(std::chrono::steady_clock::now() - start).count(); + std::chrono::duration_cast(std::chrono::steady_clock::now() - start).count(); std::this_thread::sleep_for( - std::chrono::nanoseconds(std::max(static_cast(0), this->refresh_rate_nanos - elapsed_time))); + std::chrono::nanoseconds(std::max(static_cast(0), this->refresh_rate_nanos - elapsed_time))); continue; } MYLOG_TRACE(this->logger, 0, "Detected change in custom endpoint info for '%s':\n{%s}", - custom_endpoint_host_info->get_host().c_str(), endpoint_info->to_string().c_str()); + custom_endpoint_host.c_str(), endpoint_info->to_string().c_str()); // The custom endpoint info has changed, so we need to update the set of allowed/blocked hosts. std::shared_ptr allowed_and_blocked_hosts; if (endpoint_info->get_member_list_type() == STATIC_LIST) { allowed_and_blocked_hosts = - std::make_shared(endpoint_info->get_static_members(), std::set()); + std::make_shared(endpoint_info->get_static_members(), std::set()); } else { - allowed_and_blocked_hosts = - std::make_shared(std::set(), endpoint_info->get_excluded_members()); + allowed_and_blocked_hosts = std::make_shared( + std::set(), endpoint_info->get_excluded_members()); } - custom_endpoint_cache.put(this->custom_endpoint_host_info->get_host(), endpoint_info, - CUSTOM_ENDPOINT_INFO_EXPIRATION_NANOS); + this->topology_service->set_allowed_and_blocked_hosts(allowed_and_blocked_hosts); + custom_endpoint_cache.put(this->custom_endpoint_host, endpoint_info, CUSTOM_ENDPOINT_INFO_EXPIRATION_NANOS); const long long elapsed_time = - std::chrono::duration_cast(std::chrono::steady_clock::now() - start).count(); + std::chrono::duration_cast(std::chrono::steady_clock::now() - start).count(); std::this_thread::sleep_for( - std::chrono::nanoseconds(std::max(static_cast(0), this->refresh_rate_nanos - elapsed_time))); + std::chrono::nanoseconds(std::max(static_cast(0), this->refresh_rate_nanos - elapsed_time))); } } catch (const std::exception& e) { @@ -149,7 +170,7 @@ void CUSTOM_ENDPOINT_MONITOR::run() { } std::string CUSTOM_ENDPOINT_MONITOR::get_endpoints_as_string( - const std::vector& custom_endpoints) { + const std::vector& custom_endpoints) { if (custom_endpoints.empty()) { return ""; } @@ -171,9 +192,10 @@ void CUSTOM_ENDPOINT_MONITOR::stop() { this->should_stop.store(true); this->thread_pool.stop(true); this->thread_pool.resize(0); - custom_endpoint_cache.remove(this->custom_endpoint_host_info->get_host()); + custom_endpoint_cache.remove(this->custom_endpoint_host); + this->rds_client.reset(); --SDK_HELPER; - MYLOG_TRACE(this->logger, 0, "Stopped custom endpoint monitor for '%s'", this->custom_endpoint_host_info->get_host().c_str()); + MYLOG_TRACE(this->logger, 0, "Stopped custom endpoint monitor for '%s'", this->custom_endpoint_host.c_str()); } void CUSTOM_ENDPOINT_MONITOR::clear_cache() { custom_endpoint_cache.clear(); } diff --git a/driver/custom_endpoint_monitor.h b/driver/custom_endpoint_monitor.h index e824b3d72..6c6f5b7ef 100644 --- a/driver/custom_endpoint_monitor.h +++ b/driver/custom_endpoint_monitor.h @@ -34,15 +34,23 @@ #include #include "cache_map.h" -#include "connection_handler.h" -#include "connection_proxy.h" #include "custom_endpoint_info.h" #include "host_info.h" +#include "topology_service.h" class CUSTOM_ENDPOINT_MONITOR : public std::enable_shared_from_this { public: - CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr& custom_endpoint_host_info, const std::string& endpoint_identifier, - const std::string& region, DataSource* ds, bool enable_logging = false); + CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr topology_service, + const std::string& custom_endpoint_host, const std::string& endpoint_identifier, + const std::string& region, long long refresh_rate_nanos, bool enable_logging = false); +#ifdef UNIT_TEST_BUILD + CUSTOM_ENDPOINT_MONITOR() = default; + CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr topology_service, + const std::string& custom_endpoint_host, const std::string& endpoint_identifier, + const std::string& region, long long refresh_rate_nanos, bool enable_logging, + std::shared_ptr client); +#endif + ~CUSTOM_ENDPOINT_MONITOR() = default; static bool should_dispose(); @@ -54,7 +62,7 @@ class CUSTOM_ENDPOINT_MONITOR : public std::enable_shared_from_this> custom_endpoint_cache; static constexpr long long CUSTOM_ENDPOINT_INFO_EXPIRATION_NANOS = 300000000000; // 5 minutes - std::shared_ptr custom_endpoint_host_info; + std::string custom_endpoint_host; std::string endpoint_identifier; std::string region; long long refresh_rate_nanos; @@ -63,9 +71,15 @@ class CUSTOM_ENDPOINT_MONITOR : public std::enable_shared_from_this rds_client; + std::shared_ptr topology_service; private: static std::string get_endpoints_as_string(const std::vector& custom_endpoints); + +#ifdef UNIT_TEST_BUILD + // Allows for testing private/protected methods + friend class TEST_UTILS; +#endif }; #endif diff --git a/driver/custom_endpoint_proxy.cc b/driver/custom_endpoint_proxy.cc index 777ae5bdb..041a6cffc 100644 --- a/driver/custom_endpoint_proxy.cc +++ b/driver/custom_endpoint_proxy.cc @@ -28,27 +28,32 @@ // http://www.gnu.org/licenses/gpl-2.0.html. #include "custom_endpoint_proxy.h" -#include "custom_endpoint_monitor.h" -#include "installer.h" #include "mylog.h" #include "rds_utils.h" SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD> - CUSTOM_ENDPOINT_PROXY::monitors(std::make_shared(), - std::make_shared(), CACHE_CLEANUP_RATE_NANO); + CUSTOM_ENDPOINT_PROXY::monitors(std::make_shared(), + std::make_shared(), CACHE_CLEANUP_RATE_NANO); + +bool CUSTOM_ENDPOINT_PROXY::is_monitor_cache_initialized(false); CUSTOM_ENDPOINT_PROXY::CUSTOM_ENDPOINT_PROXY(DBC* dbc, DataSource* ds) : CUSTOM_ENDPOINT_PROXY(dbc, ds, nullptr) {} CUSTOM_ENDPOINT_PROXY::CUSTOM_ENDPOINT_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy) - : CONNECTION_PROXY(dbc, ds) { + : CONNECTION_PROXY(dbc, ds) { this->next_proxy = next_proxy; + this->topology_service = dbc->get_topology_service(); if (ds->opt_LOG_QUERY) { this->logger = init_log_file(); } - this->should_wait_for_info = ds->opt_WAIT_FOR_CUSTOM_ENDPOINT_INFO; this->wait_on_cached_info_duration_ms = ds->opt_WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS; this->idle_monitor_expiration_ms = ds->opt_CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS; + + if (!is_monitor_cache_initialized) { + monitors.init_clean_up_thread(); + is_monitor_cache_initialized = true; + } } bool CUSTOM_ENDPOINT_PROXY::connect(const char* host, const char* user, const char* password, const char* database, @@ -71,8 +76,8 @@ bool CUSTOM_ENDPOINT_PROXY::connect(const char* host, const char* user, const ch : RDS_UTILS::get_rds_region(host); if (this->region.empty()) { this->set_custom_error_message( - "Unable to determine connection region. If you are using a non-standard RDS URL, please set the " - "'custom_endpoint_region' property"); + "Unable to determine connection region. If you are using a non-standard RDS URL, please set the " + "'custom_endpoint_region' property"); return false; } @@ -85,6 +90,26 @@ bool CUSTOM_ENDPOINT_PROXY::connect(const char* host, const char* user, const ch return this->next_proxy->connect(host, user, password, database, port, socket, flags); } +int CUSTOM_ENDPOINT_PROXY::query(const char* q) { + const std::shared_ptr monitor = create_monitor_if_absent(ds); + if (this->should_wait_for_info) { + // If needed, wait a short time for custom endpoint info to be discovered. + this->wait_for_custom_endpoint_info(monitor); + } + + return next_proxy->query(q); +} + +int CUSTOM_ENDPOINT_PROXY::real_query(const char* q, unsigned long length) { + const std::shared_ptr monitor = create_monitor_if_absent(ds); + if (this->should_wait_for_info) { + // If needed, wait a short time for custom endpoint info to be discovered. + this->wait_for_custom_endpoint_info(monitor); + } + + return next_proxy->real_query(q, length); +} + void CUSTOM_ENDPOINT_PROXY::wait_for_custom_endpoint_info(std::shared_ptr monitor) { bool has_custom_endpoint_info = monitor->has_custom_endpoint_info(); @@ -96,11 +121,11 @@ void CUSTOM_ENDPOINT_PROXY::wait_for_custom_endpoint_info(std::shared_ptrlogger, 0, "Custom endpoint info for '%s' was not found. Waiting %dms for the endpoint monitor to fetch info...", - this->custom_endpoint_host_info->get_host().c_str(), this->wait_on_cached_info_duration_ms) + this->custom_endpoint_host.c_str(), this->wait_on_cached_info_duration_ms) const auto wait_for_endpoint_info_timeout_nanos = - std::chrono::steady_clock::now() + std::chrono::duration_cast( - std::chrono::milliseconds(this->wait_on_cached_info_duration_ms)); + std::chrono::steady_clock::now() + std::chrono::duration_cast( + std::chrono::milliseconds(this->wait_on_cached_info_duration_ms)); while (!has_custom_endpoint_info && std::chrono::steady_clock::now() < wait_for_endpoint_info_timeout_nanos) { std::this_thread::sleep_for(std::chrono::milliseconds(100)); @@ -110,24 +135,26 @@ void CUSTOM_ENDPOINT_PROXY::wait_for_custom_endpoint_info(std::shared_ptrwait_on_cached_info_duration_ms, this->custom_endpoint_host_info->get_host().c_str()); + buf, sizeof(buf), + "The custom endpoint plugin timed out after %ld ms while waiting for custom endpoint info for host %s.", + this->wait_on_cached_info_duration_ms, this->custom_endpoint_host.c_str()); set_custom_error_message(buf); } } +std::shared_ptr CUSTOM_ENDPOINT_PROXY::create_custom_endpoint_monitor( + const long long refresh_rate_nanos) { + return std::make_shared(this->topology_service, this->custom_endpoint_host, + this->custom_endpoint_id, this->region, refresh_rate_nanos); +} + std::shared_ptr CUSTOM_ENDPOINT_PROXY::create_monitor_if_absent(DataSource* ds) { - const auto refresh_rate_nanos = std::chrono::duration_cast( - std::chrono::milliseconds(ds->opt_CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS)) - .count(); + const long long refresh_rate_nanos = std::chrono::duration_cast( + std::chrono::milliseconds(ds->opt_CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS)) + .count(); return monitors.compute_if_absent( - this->custom_endpoint_host_info->get_host(), - [=](std::string key) { - return std::make_shared(this->custom_endpoint_host_info, this->custom_endpoint_id, - this->region, ds); - }, - refresh_rate_nanos); + this->custom_endpoint_host, + [=](std::string key) { return this->create_custom_endpoint_monitor(refresh_rate_nanos); }, refresh_rate_nanos); } diff --git a/driver/custom_endpoint_proxy.h b/driver/custom_endpoint_proxy.h index a3c96d77d..83e4d8aca 100644 --- a/driver/custom_endpoint_proxy.h +++ b/driver/custom_endpoint_proxy.h @@ -27,13 +27,13 @@ // along with this program. If not, see // http://www.gnu.org/licenses/gpl-2.0.html. -#ifndef __CUSTOM_ENDPOINT_PROXY__ -#define __CUSTOM_ENDPOINT_PROXY__ +#ifndef __CUSTOM_ENDPOINT_PROXY_H__ +#define __CUSTOM_ENDPOINT_PROXY_H__ -#include #include #include "connection_proxy.h" #include "custom_endpoint_monitor.h" +#include "driver.h" #include "sliding_expiration_cache_with_clean_up_thread.h" class CUSTOM_ENDPOINT_PROXY : public CONNECTION_PROXY { @@ -41,8 +41,10 @@ class CUSTOM_ENDPOINT_PROXY : public CONNECTION_PROXY { CUSTOM_ENDPOINT_PROXY(DBC* dbc, DataSource* ds); CUSTOM_ENDPOINT_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy); - bool connect(const char* host, const char* user, const char* password, const char* database, unsigned int port, - const char* socket, unsigned long flags) override; + bool connect(const char* host, const char* user, const char* password, const char* database, + unsigned int port, const char* socket, unsigned long flags) override; + int query(const char* q) override; + int real_query(const char* q, unsigned long length) override; class CUSTOM_ENDPOINTS_SHOULD_DISPOSE_FUNC : public SHOULD_DISPOSE_FUNC> { public: @@ -62,14 +64,15 @@ class CUSTOM_ENDPOINT_PROXY : public CONNECTION_PROXY { static constexpr long long CACHE_CLEANUP_RATE_NANO = 60000000000; protected: + static bool is_monitor_cache_initialized; std::string custom_endpoint_id; std::string region; std::string custom_endpoint_host; - std::shared_ptr custom_endpoint_host_info; std::shared_ptr rds_client; bool should_wait_for_info; long wait_on_cached_info_duration_ms; long idle_monitor_expiration_ms; + std::shared_ptr topology_service; static SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD> monitors; @@ -84,6 +87,11 @@ class CUSTOM_ENDPOINT_PROXY : public CONNECTION_PROXY { private: std::shared_ptr logger; + virtual std::shared_ptr create_custom_endpoint_monitor(long long refresh_rate_nanos); +#ifdef UNIT_TEST_BUILD + // Allows for testing private/protected methods + friend class TEST_UTILS; +#endif }; #endif diff --git a/driver/driver.h b/driver/driver.h index c41caad62..a52cee9ee 100644 --- a/driver/driver.h +++ b/driver/driver.h @@ -47,6 +47,7 @@ #include "connection_handler.h" #include "connection_proxy.h" +#include "topology_service.h" #include "failover.h" /* Disable _attribute__ on non-gcc compilers. */ @@ -627,6 +628,7 @@ struct DBC FAILOVER_HANDLER *fh = nullptr; /* Failover handler */ std::shared_ptr connection_handler = nullptr; + std::shared_ptr topology_service = nullptr; DBC(ENV *p_env); void free_explicit_descriptors(); @@ -639,6 +641,10 @@ struct DBC void execute_prep_stmt(MYSQL_STMT *pstmt, std::string &query, std::vector ¶m_bind, MYSQL_BIND *result_bind); void init_proxy_chain(DataSource *dsrc); + std::shared_ptr get_topology_service() { + return this->topology_service ? this->topology_service + : std::make_shared(this->id, ds ? ds->opt_LOG_QUERY : false); + } inline bool transactions_supported() { return connection_proxy->get_server_capabilities() & CLIENT_TRANSACTIONS; diff --git a/driver/failover_handler.cc b/driver/failover_handler.cc index f3233ac4d..9bf3786da 100644 --- a/driver/failover_handler.cc +++ b/driver/failover_handler.cc @@ -52,8 +52,7 @@ const char* MYSQL_READONLY_QUERY = "SELECT @@innodb_read_only AS is_reader"; FAILOVER_HANDLER::FAILOVER_HANDLER(DBC* dbc, DataSource* ds) : FAILOVER_HANDLER( - dbc, ds, dbc ? dbc->connection_handler : nullptr, - std::make_shared(dbc ? dbc->id : 0, ds ? ds->opt_LOG_QUERY : false), + dbc, ds, dbc ? dbc->connection_handler : nullptr, dbc->get_topology_service(), std::make_shared(dbc, ds)) {} FAILOVER_HANDLER::FAILOVER_HANDLER(DBC* dbc, DataSource* ds, diff --git a/driver/handle.cc b/driver/handle.cc index 125ae86cf..a54e00408 100644 --- a/driver/handle.cc +++ b/driver/handle.cc @@ -123,12 +123,14 @@ void DBC::close() // construct a proxy chain, example: iam->efm->mysql void DBC::init_proxy_chain(DataSource* dsrc) { - CONNECTION_PROXY *head = new MYSQL_PROXY(this, dsrc); + this->topology_service = std::make_shared(this->id, ds ? ds->opt_LOG_QUERY : false); + + CONNECTION_PROXY* head = new MYSQL_PROXY(this, dsrc); if (dsrc->opt_ENABLE_CUSTOM_ENDPOINT_MONITORING) { - CONNECTION_PROXY* custom_endpoint_proxy = new CUSTOM_ENDPOINT_PROXY(this, dsrc); - custom_endpoint_proxy->set_next_proxy(head); - head = custom_endpoint_proxy; + CONNECTION_PROXY* custom_endpoint_proxy = new CUSTOM_ENDPOINT_PROXY(this, dsrc); + custom_endpoint_proxy->set_next_proxy(head); + head = custom_endpoint_proxy; } if (dsrc->opt_ENABLE_FAILURE_DETECTION) { @@ -173,6 +175,7 @@ DBC::~DBC() if (env) env->remove_dbc(this); + this->topology_service.reset(); if (connection_proxy) delete connection_proxy; diff --git a/driver/host_info.cc b/driver/host_info.cc index ba9e3beba..139a5ad29 100644 --- a/driver/host_info.cc +++ b/driver/host_info.cc @@ -29,6 +29,8 @@ #include "host_info.h" +#include "rds_utils.h" + // TODO // the entire HOST_INFO needs to be reviewed based on needed interfaces and other objects like CLUSTER_TOPOLOGY_INFO // most/all of the HOST_INFO potentially could be internal to CLUSTER_TOPOLOGY_INFO and specfic information may be accessed @@ -45,27 +47,30 @@ HOST_INFO::HOST_INFO(const char* host, int port) : HOST_INFO(host, port, UP, false) {} HOST_INFO::HOST_INFO(std::string host, int port, HOST_STATE state, bool is_writer) - : host{ host }, port{ port }, host_state{ state }, is_writer{ is_writer } -{ -} + : host{host}, host_id{RDS_UTILS::get_rds_instance_id(host)}, port{port}, host_state{state}, is_writer{is_writer} {} // would need some checks for nulls HOST_INFO::HOST_INFO(const char* host, int port, HOST_STATE state, bool is_writer) - : host{ host }, port{ port }, host_state{ state }, is_writer{ is_writer } -{ -} + : host{host}, host_id{RDS_UTILS::get_rds_instance_id(host)}, port{port}, host_state{state}, is_writer{is_writer} {} HOST_INFO::~HOST_INFO() {} /** - * Returns the host. + * Returns the host endpoint. * - * @return the host + * @return the host endpoint */ std::string HOST_INFO::get_host() { return host; } +/** + * Returns the host name. + * + * @return the host name + */ +std::string HOST_INFO::get_host_id() { return host_id; } + /** * Returns the port. * diff --git a/driver/host_info.h b/driver/host_info.h index e5c64a420..6049cac05 100644 --- a/driver/host_info.h +++ b/driver/host_info.h @@ -49,6 +49,7 @@ class HOST_INFO { int get_port(); std::string get_host(); + std::string get_host_id(); std::string get_host_port_pair(); bool equal_host_port_pair(HOST_INFO& hi); HOST_STATE get_host_state(); @@ -69,6 +70,7 @@ class HOST_INFO { private: const std::string HOST_PORT_SEPARATOR = ":"; const std::string host; + const std::string host_id; const int port = NO_PORT; HOST_STATE host_state; diff --git a/driver/rds_utils.cc b/driver/rds_utils.cc index 5f674e677..a49667a0c 100644 --- a/driver/rds_utils.cc +++ b/driver/rds_utils.cc @@ -31,7 +31,7 @@ namespace { const std::regex AURORA_DNS_PATTERN( - R"#((.+)\.(proxy-|cluster-|cluster-ro-|cluster-custom-)?([a-zA-Z0-9]+\.[a-zA-Z0-9\-]+\.rds\.amazonaws\.com))#", + R"#((.+)\.(proxy-|cluster-|cluster-ro-|cluster-custom-)?([a-zA-Z0-9]+\.([a-zA-Z0-9\-]+)\.rds\.amazonaws\.com))#", std::regex_constants::icase); const std::regex AURORA_PROXY_DNS_PATTERN(R"#((.+)\.(proxy-)+([a-zA-Z0-9]+\.[a-zA-Z0-9\-]+\.rds\.amazonaws\.com))#", std::regex_constants::icase); @@ -136,30 +136,44 @@ std::string RDS_UTILS::get_rds_cluster_host_url(std::string host) { std::string RDS_UTILS::get_rds_cluster_id(std::string host) { auto f = [host](const std::regex pattern) { std::smatch m; - if (std::regex_search(host, m, pattern) && m.size() > 1) { + if (std::regex_search(host, m, pattern) && m.size() > 1 && !m.str(2).empty()) { return m.str(1); } return std::string(); }; - auto result = f(AURORA_CLUSTER_PATTERN); + auto result = f(AURORA_DNS_PATTERN); if (!result.empty()) { return result; } - return f(AURORA_CHINA_CLUSTER_PATTERN); + return f(AURORA_CHINA_DNS_PATTERN); } +std::string RDS_UTILS::get_rds_instance_id(std::string host) { + auto f = [host](const std::regex pattern) { + std::smatch m; + if (std::regex_search(host, m, pattern) && m.size() > 1 && m.str(2).empty()) { + return m.str(1); + } + return std::string(); + }; + + auto result = f(AURORA_DNS_PATTERN); + if (!result.empty()) { + return result; + } + + return f(AURORA_CHINA_DNS_PATTERN); +} std::string RDS_UTILS::get_rds_instance_host_pattern(std::string host) { auto f = [host](const std::regex pattern) { std::smatch m; - if (std::regex_search(host, m, pattern) && m.size() > 3) { - if (!m.str(3).empty()) { - std::string result("?."); - result.append(m.str(3)); + if (std::regex_search(host, m, pattern) && m.size() > 4 && !m.str(3).empty()) { + std::string result("?."); + result.append(m.str(3)); - return result; - } + return result; } return std::string(); }; @@ -174,7 +188,10 @@ std::string RDS_UTILS::get_rds_instance_host_pattern(std::string host) { std::string RDS_UTILS::get_rds_region(std::string host) { auto f = [host](const std::regex pattern) { - // TODO: implement region + std::smatch m; + if (std::regex_search(host, m, pattern) && m.size() > 4 && !m.str(4).empty()) { + return m.str(4); + } return std::string(); }; diff --git a/driver/rds_utils.h b/driver/rds_utils.h index 34ad0cfd9..4c4b1648c 100644 --- a/driver/rds_utils.h +++ b/driver/rds_utils.h @@ -47,6 +47,7 @@ class RDS_UTILS { static std::string get_rds_cluster_host_url(std::string host); static std::string get_rds_cluster_id(std::string host); static std::string get_rds_instance_host_pattern(std::string host); + static std::string get_rds_instance_id(std::string host); static std::string get_rds_region(std::string host); }; diff --git a/driver/sliding_expiration_cache.cc b/driver/sliding_expiration_cache.cc index 423246e7e..af56f0a1c 100644 --- a/driver/sliding_expiration_cache.cc +++ b/driver/sliding_expiration_cache.cc @@ -69,9 +69,9 @@ template V SLIDING_EXPIRATION_CACHE::compute_if_absent(K key, std::function mapping_function, long long item_expiration_nanos) { this->clean_up(); - auto cache_item = std::make_shared(mapping_function(key), - std::chrono::steady_clock::now() + std::chrono::nanoseconds(item_expiration_nanos)); - this->cache[key] = cache_item; + V item = mapping_function(key); + auto cache_item = std::make_shared(item, std::chrono::steady_clock::now() + std::chrono::nanoseconds(item_expiration_nanos)); + this->cache.emplace(key, cache_item); return cache_item->with_extend_expiration(item_expiration_nanos)->item; } diff --git a/driver/sliding_expiration_cache_with_clean_up_thread.cc b/driver/sliding_expiration_cache_with_clean_up_thread.cc index 7d6056466..cc43582c5 100644 --- a/driver/sliding_expiration_cache_with_clean_up_thread.cc +++ b/driver/sliding_expiration_cache_with_clean_up_thread.cc @@ -30,7 +30,6 @@ #include "sliding_expiration_cache_with_clean_up_thread.h" #include -#include #include "custom_endpoint_monitor.h" @@ -40,7 +39,6 @@ void SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD::init_clean_up_thread() std::unique_lock lock(mutex_); if (!this->is_initialized) { this->clean_up_thread_pool.resize(this->clean_up_thread_pool.size() + 1); - this->clean_up_thread_pool.push([=](int id) { while (!should_stop) { const std::chrono::nanoseconds clean_up_interval = std::chrono::nanoseconds(this->clean_up_interval_nanos); @@ -61,42 +59,16 @@ void SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD::init_clean_up_thread() } } -template -SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD::SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD() { - this->init_clean_up_thread(); -} - -template -SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD::SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD( - std::shared_ptr> should_dispose_func, - std::shared_ptr> item_disposal_func) - : SLIDING_EXPIRATION_CACHE(std::move(should_dispose_func), std::move(item_disposal_func)) { - this->init_clean_up_thread(); -} - -template -SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD::SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD( - std::shared_ptr> should_dispose_func, - std::shared_ptr> item_disposal_func, long long clean_up_interval_nanos) - : SLIDING_EXPIRATION_CACHE(std::move(should_dispose_func), std::move(item_disposal_func), clean_up_interval_nanos) { - this->init_clean_up_thread(); -} - -#ifdef UNIT_TEST_BUILD -template -SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD::SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD( - long long clean_up_interval_nanos) { - this->clean_up_interval_nanos = clean_up_interval_nanos; - this->init_clean_up_thread(); -} -#endif - template void SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD::release_resources() { - this->should_stop = true; - this->clean_up_thread_pool.stop(true); - this->clean_up_thread_pool.resize(0); - this->is_initialized = false; + std::unique_lock lock(mutex_); + { + this->should_stop = true; + this->clean_up_thread_pool.stop(true); + this->clean_up_thread_pool.resize(0); + this->is_initialized = false; + this->clear(); + } } template class SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD; diff --git a/driver/sliding_expiration_cache_with_clean_up_thread.h b/driver/sliding_expiration_cache_with_clean_up_thread.h index 807744347..dd0583f84 100644 --- a/driver/sliding_expiration_cache_with_clean_up_thread.h +++ b/driver/sliding_expiration_cache_with_clean_up_thread.h @@ -38,12 +38,12 @@ template class SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD : public SLIDING_EXPIRATION_CACHE { public: - SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD(); + SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD() = default; SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD(std::shared_ptr> should_dispose_func, - std::shared_ptr> item_disposal_func); + std::shared_ptr> item_disposal_func){}; SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD(std::shared_ptr> should_dispose_func, std::shared_ptr> item_disposal_func, - long long clean_up_interval_nanos); + long long clean_up_interval_nanos){}; ~SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD() = default; /** @@ -52,17 +52,17 @@ class SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD : public SLIDING_EXPIRATION_ void release_resources(); #ifdef UNIT_TEST_BUILD - SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD(long long clean_up_interval_nanos); + SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD(long long clean_up_interval_nanos) { + this->clean_up_interval_nanos = clean_up_interval_nanos; + }; #endif + void init_clean_up_thread(); protected: bool is_initialized = false; bool should_stop = false; std::mutex mutex_; ctpl::thread_pool clean_up_thread_pool; - - private: - void init_clean_up_thread(); }; #endif diff --git a/driver/topology_service.cc b/driver/topology_service.cc index 265dee136..7364b9ad2 100644 --- a/driver/topology_service.cc +++ b/driver/topology_service.cc @@ -164,26 +164,46 @@ std::shared_ptr TOPOLOGY_SERVICE::get_cached_topology() { return get_from_cache(); } -//TODO consider the return value -//Note to determine whether or not force_update succeeded one would compare +// TODO consider the return value +// Note to determine whether or not force_update succeeded one would compare // CLUSTER_TOPOLOGY_INFO->time_last_updated() prior and after the call if non-null information was given prior. std::shared_ptr TOPOLOGY_SERVICE::get_topology(CONNECTION_PROXY* connection, bool force_update) { - //TODO reconsider using this cache. It appears that we only store information for the current cluster Id. + // TODO reconsider using this cache. It appears that we only store information for the current cluster Id. // therefore instead of a map we can just keep CLUSTER_TOPOLOGY_INFO* topology_info member variable. - auto cached_topology = get_from_cache(); - if (!cached_topology - || force_update - || refresh_needed(cached_topology->time_last_updated())) - { - auto latest_topology = query_for_topology(connection); - if (latest_topology) { + auto topology = get_from_cache(); + if (!topology || force_update || refresh_needed(topology->time_last_updated())) { + if (auto latest_topology = query_for_topology(connection)) { put_to_cache(latest_topology); - return latest_topology; + topology = latest_topology; } } - return cached_topology; + if (!this->allowed_and_blocked_hosts) { + return topology; + } + + std::set allowed_list = this->allowed_and_blocked_hosts->get_allowed_host_ids(); + std::set blocked_list = this->allowed_and_blocked_hosts->get_blocked_host_ids(); + + const std::shared_ptr filtered_topology = topology; + if (allowed_list.size() > 0) { + for (const auto& host : topology->get_instances()) { + if (allowed_list.find(host->get_host_id()) != allowed_list.end()) { + filtered_topology->add_host(host); + } + } + } + + if (blocked_list.size() > 0) { + for (const auto& host : filtered_topology->get_instances()) { + // Remove blocked hosts from the filtered_topology. + if (blocked_list.find(host->get_host_id()) != blocked_list.end()) { + filtered_topology->remove_host(host); + } + } + } + return filtered_topology; } // TODO consider thread safety and usage of pointers diff --git a/driver/topology_service.h b/driver/topology_service.h index 60f4b6d4d..3464e8519 100644 --- a/driver/topology_service.h +++ b/driver/topology_service.h @@ -24,7 +24,7 @@ // See the GNU General Public License, version 2.0, for more details. // // You should have received a copy of the GNU General Public License -// along with this program. If not, see +// along with this program. If not, see // http://www.gnu.org/licenses/gpl-2.0.html. #ifndef __TOPOLOGYSERVICE_H__ @@ -36,16 +36,18 @@ #include "cluster_topology_info.h" #include "connection_proxy.h" -#include -#include #include #include +#include +#include +#include "allowed_and_blocked_hosts.h" // TODO - consider - do we really need miliseconds for refresh? - the default numbers here are already 30 seconds.000; #define DEFAULT_REFRESH_RATE_IN_MILLISECONDS 30000 #define WRITER_SESSION_ID "MASTER_SESSION_ID" -#define RETRIEVE_TOPOLOGY_SQL "SELECT SERVER_ID, SESSION_ID, LAST_UPDATE_TIMESTAMP, REPLICA_LAG_IN_MILLISECONDS \ +#define RETRIEVE_TOPOLOGY_SQL \ + "SELECT SERVER_ID, SESSION_ID, LAST_UPDATE_TIMESTAMP, REPLICA_LAG_IN_MILLISECONDS \ FROM information_schema.replica_host_status \ WHERE time_to_sec(timediff(now(), LAST_UPDATE_TIMESTAMP)) <= 300 \ ORDER BY LAST_UPDATE_TIMESTAMP DESC" @@ -54,69 +56,74 @@ static std::map> topology_ca static std::mutex topology_cache_mutex; class TOPOLOGY_SERVICE { -public: - TOPOLOGY_SERVICE(unsigned long dbc_id, bool enable_logging = false); - TOPOLOGY_SERVICE(const TOPOLOGY_SERVICE&); - virtual ~TOPOLOGY_SERVICE(); - - virtual void set_cluster_id(std::string cluster_id); - virtual void set_cluster_instance_template(std::shared_ptr host_template); //is this equivalent to setcluster_instance_host - - virtual std::shared_ptr get_topology( - CONNECTION_PROXY* connection, bool force_update = false); - std::shared_ptr get_cached_topology(); - - std::shared_ptr get_last_used_reader(); - void set_last_used_reader(std::shared_ptr reader); - std::set get_down_hosts(); - virtual void mark_host_down(std::shared_ptr host); - virtual void mark_host_up(std::shared_ptr host); - void set_refresh_rate(int refresh_rate); - void set_gather_metric(bool can_gather); - void clear_all(); - void clear(); - - // Property Keys - const std::string SESSION_ID = "TOPOLOGY_SERVICE_SESSION_ID"; - const std::string LAST_UPDATED = "TOPOLOGY_SERVICE_LAST_UPDATE_TIMESTAMP"; - const std::string REPLICA_LAG = "TOPOLOGY_SERVICE_REPLICA_LAG_IN_MILLISECONDS"; - const std::string INSTANCE_NAME = "TOPOLOGY_SERVICE_SERVER_ID"; - -private: - const int DEFAULT_CACHE_EXPIRE_MS = 5 * 60 * 1000; // 5 min - - const std::string GET_INSTANCE_NAME_SQL = "SELECT @@aurora_server_id"; - const std::string GET_INSTANCE_NAME_COL = "@@aurora_server_id"; - - const std::string FIELD_SERVER_ID = "SERVER_ID"; - const std::string FIELD_SESSION_ID = "SESSION_ID"; - const std::string FIELD_LAST_UPDATED = "LAST_UPDATE_TIMESTAMP"; - const std::string FIELD_REPLICA_LAG = "REPLICA_LAG_IN_MILLISECONDS"; - - std::shared_ptr logger = nullptr; - unsigned long dbc_id = 0; - -protected: - const int NO_CONNECTION_INDEX = -1; - int refresh_rate_in_ms; - - std::string cluster_id; - std::shared_ptr cluster_instance_host; - - std::shared_ptr metrics_container; - - bool refresh_needed(std::time_t last_updated); - std::shared_ptr query_for_topology(CONNECTION_PROXY* connection); - std::shared_ptr create_host(MYSQL_ROW& row); - std::string get_host_endpoint(const char* node_name); - static bool does_instance_exist( - std::map>& instances, - std::shared_ptr host_info); - - std::shared_ptr get_from_cache(); - void put_to_cache(std::shared_ptr topology_info); - - MYSQL_RES* try_execute_query(CONNECTION_PROXY* connection_proxy, const char* query); + public: + TOPOLOGY_SERVICE(unsigned long dbc_id, bool enable_logging = false); + TOPOLOGY_SERVICE(const TOPOLOGY_SERVICE&); + virtual ~TOPOLOGY_SERVICE(); + + virtual void set_cluster_id(std::string cluster_id); + virtual void set_cluster_instance_template( + std::shared_ptr host_template); // is this equivalent to set_cluster_instance_host + + virtual std::shared_ptr get_topology(CONNECTION_PROXY* connection, bool force_update = false); + + std::shared_ptr get_cached_topology(); + + std::shared_ptr get_last_used_reader(); + void set_last_used_reader(std::shared_ptr reader); + std::set get_down_hosts(); + virtual void mark_host_down(std::shared_ptr host); + virtual void mark_host_up(std::shared_ptr host); + void set_refresh_rate(int refresh_rate); + void set_gather_metric(bool can_gather); + void clear_all(); + void clear(); + void set_allowed_and_blocked_hosts(std::shared_ptr hosts) { + this->allowed_and_blocked_hosts = hosts; + }; + + // Property Keys + const std::string SESSION_ID = "TOPOLOGY_SERVICE_SESSION_ID"; + const std::string LAST_UPDATED = "TOPOLOGY_SERVICE_LAST_UPDATE_TIMESTAMP"; + const std::string REPLICA_LAG = "TOPOLOGY_SERVICE_REPLICA_LAG_IN_MILLISECONDS"; + const std::string INSTANCE_NAME = "TOPOLOGY_SERVICE_SERVER_ID"; + + private: + const int DEFAULT_CACHE_EXPIRE_MS = 5 * 60 * 1000; // 5 min + + const std::string GET_INSTANCE_NAME_SQL = "SELECT @@aurora_server_id"; + const std::string GET_INSTANCE_NAME_COL = "@@aurora_server_id"; + + const std::string FIELD_SERVER_ID = "SERVER_ID"; + const std::string FIELD_SESSION_ID = "SESSION_ID"; + const std::string FIELD_LAST_UPDATED = "LAST_UPDATE_TIMESTAMP"; + const std::string FIELD_REPLICA_LAG = "REPLICA_LAG_IN_MILLISECONDS"; + + std::shared_ptr logger = nullptr; + unsigned long dbc_id = 0; + + protected: + const int NO_CONNECTION_INDEX = -1; + int refresh_rate_in_ms; + + std::string cluster_id; + std::shared_ptr cluster_instance_host; + + std::shared_ptr metrics_container; + + std::shared_ptr allowed_and_blocked_hosts; + + bool refresh_needed(std::time_t last_updated); + std::shared_ptr query_for_topology(CONNECTION_PROXY* connection); + std::shared_ptr create_host(MYSQL_ROW& row); + std::string get_host_endpoint(const char* node_name); + static bool does_instance_exist(std::map>& instances, + std::shared_ptr host_info); + + std::shared_ptr get_from_cache(); + void put_to_cache(std::shared_ptr topology_info); + + MYSQL_RES* try_execute_query(CONNECTION_PROXY* connection_proxy, const char* query); }; #endif /* __TOPOLOGYSERVICE_H__ */ diff --git a/integration/CMakeLists.txt b/integration/CMakeLists.txt index fe716112a..733c42dd9 100644 --- a/integration/CMakeLists.txt +++ b/integration/CMakeLists.txt @@ -99,6 +99,7 @@ set(TEST_SOURCES base_failover_integration_test.cc connection_string_builder_test.cc) set(INTEGRATION_TESTS + custom_endpoint_integration_test.cc iam_authentication_integration_test.cc secrets_manager_integration_test.cc network_failover_integration_test.cc diff --git a/integration/base_failover_integration_test.cc b/integration/base_failover_integration_test.cc index 77390b649..3e1e429b4 100644 --- a/integration/base_failover_integration_test.cc +++ b/integration/base_failover_integration_test.cc @@ -37,7 +37,9 @@ #include #include #include +#include #include +#include #include #include @@ -49,14 +51,17 @@ #include #include #include +#include #include #include -#include -#include -#include +#if defined(__APPLE__) || defined(__linux__) #include +#include #include +#include +#include +#endif #include "connection_string_builder.h" #include "integration_test_utils.h" @@ -267,6 +272,13 @@ class BaseFailoverIntegrationTest : public testing::Test { } } + static Aws::RDS::Model::DBClusterEndpoint get_custom_endpoint_info(const Aws::RDS::RDSClient& client, const std::string& endpoint_id) { + Aws::RDS::Model::DescribeDBClusterEndpointsRequest request; + request.SetDBClusterEndpointIdentifier(endpoint_id); + const auto response = client.DescribeDBClusterEndpoints(request); + return response.GetResult().GetDBClusterEndpoints()[0]; + } + static Aws::RDS::Model::DBClusterMember get_DB_cluster_writer_instance(const Aws::RDS::RDSClient& client, const Aws::String& cluster_id) { Aws::RDS::Model::DBClusterMember instance; const Aws::RDS::Model::DBCluster cluster = get_DB_cluster(client, cluster_id); diff --git a/integration/connection_string_builder.h b/integration/connection_string_builder.h index 8eb450d25..f31cf1a6a 100644 --- a/integration/connection_string_builder.h +++ b/integration/connection_string_builder.h @@ -24,22 +24,22 @@ // See the GNU General Public License, version 2.0, for more details. // // You should have received a copy of the GNU General Public License -// along with this program. If not, see +// along with this program. If not, see // http://www.gnu.org/licenses/gpl-2.0.html. #ifndef __CONNECTIONSTRINGBUILDER_H__ #define __CONNECTIONSTRINGBUILDER_H__ +#include #include #include -#include class ConnectionStringBuilder; class ConnectionString { - public: - // friend class so the builder can access ConnectionString private attributes - friend class ConnectionStringBuilder; + public: + // friend class so the builder can access ConnectionString private attributes + friend class ConnectionStringBuilder; ConnectionString() : m_dsn(""), m_server(""), m_port(-1), m_uid(""), m_pwd(""), m_db(""), m_log_query(true), @@ -49,9 +49,9 @@ class ConnectionString { m_failure_detection_interval(-1), m_failure_detection_count(-1), m_monitor_disposal_time(-1), m_read_timeout(-1), m_write_timeout(-1), m_auth_mode(""), m_auth_region(""), m_auth_host(""), m_auth_port(-1), m_auth_expiration(-1), m_secret_id(""), - + is_set_uid(false), is_set_pwd(false), is_set_db(false), is_set_log_query(false), - is_set_failover_mode(false), + is_set_failover_mode(false), is_set_multi_statements(false), is_set_enable_cluster_failover(false), is_set_failover_timeout(false), is_set_connect_timeout(false), is_set_network_timeout(false), is_set_host_pattern(false), is_set_enable_failure_detection(false), is_set_failure_detection_time(false), is_set_failure_detection_timeout(false), @@ -59,417 +59,493 @@ class ConnectionString { is_set_read_timeout(false), is_set_write_timeout(false), is_set_auth_mode(false), is_set_auth_region(false), is_set_auth_host(false), is_set_auth_port(false), is_set_auth_expiration(false), is_set_secret_id(false) {}; - std::string get_connection_string() const { - char conn_in[4096] = "\0"; - int length = 0; - length += sprintf(conn_in, "DSN=%s;SERVER=%s;PORT=%d;", m_dsn.c_str(), m_server.c_str(), m_port); - - if (is_set_uid) { - length += sprintf(conn_in + length, "UID=%s;", m_uid.c_str()); - } - if (is_set_pwd) { - length += sprintf(conn_in + length, "PWD=%s;", m_pwd.c_str()); - } - if (is_set_db) { - length += sprintf(conn_in + length, "DATABASE=%s;", m_db.c_str()); - } - if (is_set_log_query) { - length += sprintf(conn_in + length, "LOG_QUERY=%d;", m_log_query ? 1 : 0); - } - if (is_set_failover_mode) { - length += sprintf(conn_in + length, "FAILOVER_MODE=%s;", m_failover_mode.c_str()); - } - if (is_set_multi_statements) { - length += sprintf(conn_in + length, "MULTI_STATEMENTS=%d;", m_multi_statements ? 1 : 0); - } - if (is_set_enable_cluster_failover) { - length += sprintf(conn_in + length, "ENABLE_CLUSTER_FAILOVER=%d;", m_enable_cluster_failover ? 1 : 0); - } - if (is_set_failover_timeout) { - length += sprintf(conn_in + length, "FAILOVER_TIMEOUT=%d;", m_failover_timeout); - } - if (is_set_connect_timeout) { - length += sprintf(conn_in + length, "CONNECT_TIMEOUT=%d;", m_connect_timeout); - } - if (is_set_network_timeout) { - length += sprintf(conn_in + length, "NETWORK_TIMEOUT=%d;", m_network_timeout); - } - if (is_set_host_pattern) { - length += sprintf(conn_in + length, "HOST_PATTERN=%s;", m_host_pattern.c_str()); - } - if (is_set_enable_failure_detection) { - length += sprintf(conn_in + length, "ENABLE_FAILURE_DETECTION=%d;", m_enable_failure_detection ? 1 : 0); - } - if (is_set_failure_detection_time) { - length += sprintf(conn_in + length, "FAILURE_DETECTION_TIME=%d;", m_failure_detection_time); - } - if (is_set_failure_detection_timeout) { - length += sprintf(conn_in + length, "FAILURE_DETECTION_TIMEOUT=%d;", m_failure_detection_timeout); - } - if (is_set_failure_detection_interval) { - length += sprintf(conn_in + length, "FAILURE_DETECTION_INTERVAL=%d;", m_failure_detection_interval); - } - if (is_set_failure_detection_count) { - length += sprintf(conn_in + length, "FAILURE_DETECTION_COUNT=%d;", m_failure_detection_count); - } - if (is_set_monitor_disposal_time) { - length += sprintf(conn_in + length, "MONITOR_DISPOSAL_TIME=%d;", m_monitor_disposal_time); - } - if (is_set_read_timeout) { - length += sprintf(conn_in + length, "READTIMEOUT=%d;", m_read_timeout); - } - if (is_set_write_timeout) { - length += sprintf(conn_in + length, "WRITETIMEOUT=%d;", m_write_timeout); - } - if (is_set_auth_mode) { - length += sprintf(conn_in + length, "AUTHENTICATION_MODE=%s;", m_auth_mode.c_str()); - } - if (is_set_auth_region) { - length += sprintf(conn_in + length, "AWS_REGION=%s;", m_auth_region.c_str()); - } - if (is_set_auth_host) { - length += sprintf(conn_in + length, "IAM_HOST=%s;", m_auth_host.c_str()); - } - if (is_set_auth_port) { - length += sprintf(conn_in + length, "IAM_PORT=%d;", m_auth_port); - } - if (is_set_auth_expiration) { - length += sprintf(conn_in + length, "IAM_EXPIRATION_TIME=%d;", m_auth_expiration); - } - if (is_set_secret_id) { - length += sprintf(conn_in + length, "SECRET_ID=%s;", m_secret_id.c_str()); - } - snprintf(conn_in + length, sizeof(conn_in) - length, "\0"); - - std::string connection_string(conn_in); - return connection_string; - } - - private: - // Required fields - std::string m_dsn, m_server; - int m_port; - - // Optional fields - std::string m_uid, m_pwd, m_db; - bool m_log_query, m_multi_statements, m_enable_cluster_failover; - int m_failover_timeout, m_connect_timeout, m_network_timeout; - std::string m_host_pattern, m_failover_mode; - bool m_enable_failure_detection; - int m_failure_detection_time, m_failure_detection_timeout, m_failure_detection_interval, m_failure_detection_count, m_monitor_disposal_time, m_read_timeout, m_write_timeout; - std::string m_auth_mode, m_auth_region, m_auth_host, m_secret_id; - int m_auth_port, m_auth_expiration; - - bool is_set_uid, is_set_pwd, is_set_db; - bool is_set_log_query, is_set_failover_mode, is_set_multi_statements; - bool is_set_enable_cluster_failover; - bool is_set_failover_timeout, is_set_connect_timeout, is_set_network_timeout; - bool is_set_host_pattern; - bool is_set_enable_failure_detection; - bool is_set_failure_detection_time, is_set_failure_detection_timeout, is_set_failure_detection_interval, is_set_failure_detection_count; - bool is_set_monitor_disposal_time; - bool is_set_read_timeout, is_set_write_timeout; - bool is_set_auth_mode, is_set_auth_region, is_set_auth_host, is_set_auth_port, is_set_auth_expiration, is_set_secret_id; - - void set_dsn(const std::string& dsn) { - m_dsn = dsn; - } - - void set_server(const std::string& server) { - m_server = server; - } - - void set_port(const int& port) { - m_port = port; - } - - void set_uid(const std::string& uid) { - m_uid = uid; - is_set_uid = true; - } + std::string get_connection_string() const { + char conn_in[4096] = "\0"; + int length = 0; + length += sprintf(conn_in, "DSN=%s;SERVER=%s;PORT=%d;", m_dsn.c_str(), m_server.c_str(), m_port); - void set_pwd(const std::string& pwd) { - m_pwd = pwd; - is_set_pwd = true; + if (is_set_uid) { + length += sprintf(conn_in + length, "UID=%s;", m_uid.c_str()); } - - void set_db(const std::string& db) { - m_db = db; - is_set_db = true; + if (is_set_pwd) { + length += sprintf(conn_in + length, "PWD=%s;", m_pwd.c_str()); } - - void set_log_query(const bool& log_query) { - m_log_query = log_query; - is_set_log_query = true; - } - - void set_failover_mode(const std::string& failover_mode) { - m_failover_mode = failover_mode; - is_set_failover_mode = true; + if (is_set_db) { + length += sprintf(conn_in + length, "DATABASE=%s;", m_db.c_str()); } - - void set_multi_statements(const bool& multi_statements) { - m_multi_statements = multi_statements; - is_set_multi_statements = true; + if (is_set_log_query) { + length += sprintf(conn_in + length, "LOG_QUERY=%d;", m_log_query ? 1 : 0); } - - void set_enable_cluster_failover(const bool& enable_cluster_failover) { - m_enable_cluster_failover = enable_cluster_failover; - is_set_enable_cluster_failover = true; + if (is_set_failover_mode) { + length += sprintf(conn_in + length, "FAILOVER_MODE=%s;", m_failover_mode.c_str()); } - - void set_failover_timeout(const int& failover_timeout) { - m_failover_timeout = failover_timeout; - is_set_failover_timeout = true; + if (is_set_multi_statements) { + length += sprintf(conn_in + length, "MULTI_STATEMENTS=%d;", m_multi_statements ? 1 : 0); } - - void set_connect_timeout(const int& connect_timeout) { - m_connect_timeout = connect_timeout; - is_set_connect_timeout = true; + if (is_set_enable_cluster_failover) { + length += sprintf(conn_in + length, "ENABLE_CLUSTER_FAILOVER=%d;", m_enable_cluster_failover ? 1 : 0); } - - void set_network_timeout(const int& network_timeout) { - m_network_timeout = network_timeout; - is_set_network_timeout = true; + if (is_set_failover_timeout) { + length += sprintf(conn_in + length, "FAILOVER_TIMEOUT=%d;", m_failover_timeout); } - - void set_host_pattern(const std::string& host_pattern) { - m_host_pattern = host_pattern; - is_set_host_pattern = true; + if (is_set_connect_timeout) { + length += sprintf(conn_in + length, "CONNECT_TIMEOUT=%d;", m_connect_timeout); } - - void set_enable_failure_detection(const bool& enable_failure_detection) { - m_enable_failure_detection = enable_failure_detection; - is_set_enable_failure_detection = true; + if (is_set_network_timeout) { + length += sprintf(conn_in + length, "NETWORK_TIMEOUT=%d;", m_network_timeout); } - - void set_failure_detection_time(const int& failure_detection_time) { - m_failure_detection_time = failure_detection_time; - is_set_failure_detection_time = true; + if (is_set_host_pattern) { + length += sprintf(conn_in + length, "HOST_PATTERN=%s;", m_host_pattern.c_str()); } - - void set_failure_detection_timeout(const int& failure_detection_timeout) { - m_failure_detection_timeout = failure_detection_timeout; - is_set_failure_detection_timeout = true; + if (is_set_enable_failure_detection) { + length += sprintf(conn_in + length, "ENABLE_FAILURE_DETECTION=%d;", m_enable_failure_detection ? 1 : 0); } - - void set_failure_detection_interval(const int& failure_detection_interval) { - m_failure_detection_interval = failure_detection_interval; - is_set_failure_detection_interval = true; + if (is_set_failure_detection_time) { + length += sprintf(conn_in + length, "FAILURE_DETECTION_TIME=%d;", m_failure_detection_time); } - - void set_failure_detection_count(const int& failure_detection_count) { - m_failure_detection_count = failure_detection_count; - is_set_failure_detection_count = true; + if (is_set_failure_detection_timeout) { + length += sprintf(conn_in + length, "FAILURE_DETECTION_TIMEOUT=%d;", m_failure_detection_timeout); } - - void set_monitor_disposal_time(const int& monitor_disposal_time) { - m_monitor_disposal_time = monitor_disposal_time; - is_set_monitor_disposal_time = true; + if (is_set_failure_detection_interval) { + length += sprintf(conn_in + length, "FAILURE_DETECTION_INTERVAL=%d;", m_failure_detection_interval); } - - void set_read_timeout(const int& read_timeout) { - m_read_timeout = read_timeout; - is_set_read_timeout = true; + if (is_set_failure_detection_count) { + length += sprintf(conn_in + length, "FAILURE_DETECTION_COUNT=%d;", m_failure_detection_count); } - - void set_write_timeout(const int& write_timeout) { - m_write_timeout = write_timeout; - is_set_write_timeout = true; + if (is_set_monitor_disposal_time) { + length += sprintf(conn_in + length, "MONITOR_DISPOSAL_TIME=%d;", m_monitor_disposal_time); } - - void set_auth_mode(const std::string& auth_mode) { - m_auth_mode = auth_mode; - is_set_auth_mode = true; + if (is_set_read_timeout) { + length += sprintf(conn_in + length, "READTIMEOUT=%d;", m_read_timeout); } - - void set_auth_region(const std::string& auth_region) { - m_auth_region = auth_region; - is_set_auth_region = true; + if (is_set_write_timeout) { + length += sprintf(conn_in + length, "WRITETIMEOUT=%d;", m_write_timeout); } - - void set_auth_host(const std::string& auth_host) { - m_auth_host = auth_host; - is_set_auth_host = true; + if (is_set_auth_mode) { + length += sprintf(conn_in + length, "AUTHENTICATION_MODE=%s;", m_auth_mode.c_str()); } - - void set_auth_port(const int& auth_port) { - m_auth_port = auth_port; - is_set_auth_port = true; + if (is_set_auth_region) { + length += sprintf(conn_in + length, "AWS_REGION=%s;", m_auth_region.c_str()); } - - void set_auth_expiration(const int& auth_expiration) { - m_auth_expiration = auth_expiration; - is_set_auth_expiration = true; + if (is_set_auth_host) { + length += sprintf(conn_in + length, "IAM_HOST=%s;", m_auth_host.c_str()); } - - void set_secret_id(const std::string& secret_id) { - m_secret_id = secret_id; - is_set_secret_id = true; + if (is_set_auth_port) { + length += sprintf(conn_in + length, "IAM_PORT=%d;", m_auth_port); } -}; - -class ConnectionStringBuilder { - public: - ConnectionStringBuilder() { - connection_string.reset(new ConnectionString()); + if (is_set_auth_expiration) { + length += sprintf(conn_in + length, "IAM_EXPIRATION_TIME=%d;", m_auth_expiration); } - - ConnectionStringBuilder& withDSN(const std::string& dsn) { - connection_string->set_dsn(dsn); - return *this; + if (is_set_secret_id) { + length += sprintf(conn_in + length, "SECRET_ID=%s;", m_secret_id.c_str()); } - - ConnectionStringBuilder& withServer(const std::string& server) { - connection_string->set_server(server); - return *this; + if (is_set_enable_custom_endpoint_monitoring) { + length += sprintf(conn_in + length, "ENABLE_CUSTOM_ENDPOINT_MONITORING=%d;", m_enable_custom_endpoint_monitoring ? 1 : 0); } - - ConnectionStringBuilder& withPort(const int& port) { - connection_string->set_port(port); - return *this; + if (is_set_custom_endpoint_region) { + length += sprintf(conn_in + length, "CUSTOM_ENDPOINT_REGION=%s;", m_custom_endpoint_region.c_str()); } - - ConnectionStringBuilder& withUID(const std::string& uid) { - connection_string->set_uid(uid); - return *this; + if (is_set_should_wait_for_info) { + length += sprintf(conn_in + length, "WAIT_FOR_CUSTOM_ENDPOINT_INFO=%d;", m_should_wait_for_info ? 1 : 0); } - - ConnectionStringBuilder& withPWD(const std::string& pwd) { - connection_string->set_pwd(pwd); - return *this; + if (is_set_custom_endpoint_info_refresh_rate_ms) { + length += sprintf(conn_in + length, "CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS=%d;", m_custom_endpoint_info_refresh_rate_ms); } - - ConnectionStringBuilder& withDatabase(const std::string& db) { - connection_string->set_db(db); - return *this; + if (is_set_wait_on_cached_info_duration_ms) { + length += sprintf(conn_in + length, "WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS=%ld;", m_wait_on_cached_info_duration_ms); } - - ConnectionStringBuilder& withLogQuery(const bool& log_query) { - connection_string->set_log_query(log_query); - return *this; + if (is_set_idle_monitor_expiration_ms) { + length += sprintf(conn_in + length, "CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS=%ld;", m_idle_monitor_expiration_ms); } + snprintf(conn_in + length, sizeof(conn_in) - length, "\0"); - ConnectionStringBuilder& withFailoverMode(const std::string& failover_mode) { - connection_string->set_failover_mode(failover_mode); - return *this; - } + std::string connection_string(conn_in); + return connection_string; + } - ConnectionStringBuilder& withMultiStatements(const bool& multi_statements) { - connection_string->set_multi_statements(multi_statements); - return *this; - } + private: + // Required fields + std::string m_dsn, m_server; + int m_port; - ConnectionStringBuilder& withEnableClusterFailover(const bool& enable_cluster_failover) { - connection_string->set_enable_cluster_failover(enable_cluster_failover); - return *this; - } - - ConnectionStringBuilder& withFailoverTimeout(const int& failover_t) { - connection_string->set_failover_timeout(failover_t); - return *this; - } - - ConnectionStringBuilder& withConnectTimeout(const int& connect_timeout) { - connection_string->set_connect_timeout(connect_timeout); - return *this; - } - - ConnectionStringBuilder& withNetworkTimeout(const int& network_timeout) { - connection_string->set_network_timeout(network_timeout); - return *this; - } - - ConnectionStringBuilder& withHostPattern(const std::string& host_pattern) { - connection_string->set_host_pattern(host_pattern); - return *this; - } - - ConnectionStringBuilder& withEnableFailureDetection(const bool& enable_failure_detection) { - connection_string->set_enable_failure_detection(enable_failure_detection); - return *this; - } - - ConnectionStringBuilder& withFailureDetectionTime(const int& failure_detection_time) { - connection_string->set_failure_detection_time(failure_detection_time); - return *this; - } - - ConnectionStringBuilder& withFailureDetectionTimeout(const int& failure_detection_timeout) { - connection_string->set_failure_detection_timeout(failure_detection_timeout); - return *this; - } - - ConnectionStringBuilder& withFailureDetectionInterval(const int& failure_detection_interval) { - connection_string->set_failure_detection_interval(failure_detection_interval); - return *this; - } - - ConnectionStringBuilder& withFailureDetectionCount(const int& failure_detection_count) { - connection_string->set_failure_detection_count(failure_detection_count); - return *this; - } - - ConnectionStringBuilder& withMonitorDisposalTime(const int& monitor_disposal_time) { - connection_string->set_monitor_disposal_time(monitor_disposal_time); - return *this; - } - - ConnectionStringBuilder& withReadTimeout(const int& read_timeout) { - connection_string->set_read_timeout(read_timeout); - return *this; - } - - ConnectionStringBuilder& withWriteTimeout(const int& write_timeout) { - connection_string->set_write_timeout(write_timeout); - return *this; - } - - ConnectionStringBuilder& withAuthMode(const std::string& auth_mode) { - connection_string->set_auth_mode(auth_mode); - return *this; - } - - ConnectionStringBuilder& withAuthRegion(const std::string& auth_region) { - connection_string->set_auth_region(auth_region); - return *this; - } - - ConnectionStringBuilder& withAuthHost(const std::string& auth_host) { - connection_string->set_auth_host(auth_host); - return *this; - } - - ConnectionStringBuilder& withAuthPort(const int& auth_port) { - connection_string->set_auth_port(auth_port); - return *this; - } - - ConnectionStringBuilder& withAuthExpiration(const int& auth_expiration) { - connection_string->set_auth_expiration(auth_expiration); - return *this; - } - - ConnectionStringBuilder& withSecretId(const std::string& secret_id) { - connection_string->set_secret_id(secret_id); - return *this; - } + // Optional fields + std::string m_uid, m_pwd, m_db; + bool m_log_query, m_multi_statements, m_enable_cluster_failover; + int m_failover_timeout, m_connect_timeout, m_network_timeout; + std::string m_host_pattern, m_failover_mode; + bool m_enable_failure_detection; + int m_failure_detection_time, m_failure_detection_timeout, m_failure_detection_interval, m_failure_detection_count, + m_monitor_disposal_time, m_read_timeout, m_write_timeout; + std::string m_auth_mode, m_auth_region, m_auth_host, m_secret_id, m_custom_endpoint_region; + int m_auth_port, m_auth_expiration; + bool m_enable_custom_endpoint_monitoring, m_should_wait_for_info; + long m_custom_endpoint_info_refresh_rate_ms, m_wait_on_cached_info_duration_ms, m_idle_monitor_expiration_ms; + + bool is_set_uid, is_set_pwd, is_set_db; + bool is_set_log_query, is_set_failover_mode, is_set_multi_statements; + bool is_set_enable_cluster_failover; + bool is_set_failover_timeout, is_set_connect_timeout, is_set_network_timeout; + bool is_set_host_pattern; + bool is_set_enable_failure_detection; + bool is_set_failure_detection_time, is_set_failure_detection_timeout, is_set_failure_detection_interval, + is_set_failure_detection_count; + bool is_set_monitor_disposal_time; + bool is_set_read_timeout, is_set_write_timeout; + bool is_set_auth_mode, is_set_auth_region, is_set_auth_host, is_set_auth_port, is_set_auth_expiration, + is_set_secret_id; + bool is_set_enable_custom_endpoint_monitoring, is_set_should_wait_for_info, is_set_custom_endpoint_region; + bool is_set_custom_endpoint_info_refresh_rate_ms, is_set_wait_on_cached_info_duration_ms, is_set_idle_monitor_expiration_ms; + + void set_dsn(const std::string& dsn) { m_dsn = dsn; } + + void set_server(const std::string& server) { m_server = server; } + + void set_port(const int& port) { m_port = port; } + + void set_uid(const std::string& uid) { + m_uid = uid; + is_set_uid = true; + } + + void set_pwd(const std::string& pwd) { + m_pwd = pwd; + is_set_pwd = true; + } + + void set_db(const std::string& db) { + m_db = db; + is_set_db = true; + } + + void set_log_query(const bool& log_query) { + m_log_query = log_query; + is_set_log_query = true; + } + + void set_failover_mode(const std::string& failover_mode) { + m_failover_mode = failover_mode; + is_set_failover_mode = true; + } + + void set_multi_statements(const bool& multi_statements) { + m_multi_statements = multi_statements; + is_set_multi_statements = true; + } + + void set_enable_cluster_failover(const bool& enable_cluster_failover) { + m_enable_cluster_failover = enable_cluster_failover; + is_set_enable_cluster_failover = true; + } + + void set_failover_timeout(const int& failover_timeout) { + m_failover_timeout = failover_timeout; + is_set_failover_timeout = true; + } + + void set_connect_timeout(const int& connect_timeout) { + m_connect_timeout = connect_timeout; + is_set_connect_timeout = true; + } + + void set_network_timeout(const int& network_timeout) { + m_network_timeout = network_timeout; + is_set_network_timeout = true; + } + + void set_host_pattern(const std::string& host_pattern) { + m_host_pattern = host_pattern; + is_set_host_pattern = true; + } + + void set_enable_failure_detection(const bool& enable_failure_detection) { + m_enable_failure_detection = enable_failure_detection; + is_set_enable_failure_detection = true; + } + + void set_failure_detection_time(const int& failure_detection_time) { + m_failure_detection_time = failure_detection_time; + is_set_failure_detection_time = true; + } + + void set_failure_detection_timeout(const int& failure_detection_timeout) { + m_failure_detection_timeout = failure_detection_timeout; + is_set_failure_detection_timeout = true; + } + + void set_failure_detection_interval(const int& failure_detection_interval) { + m_failure_detection_interval = failure_detection_interval; + is_set_failure_detection_interval = true; + } + + void set_failure_detection_count(const int& failure_detection_count) { + m_failure_detection_count = failure_detection_count; + is_set_failure_detection_count = true; + } + + void set_monitor_disposal_time(const int& monitor_disposal_time) { + m_monitor_disposal_time = monitor_disposal_time; + is_set_monitor_disposal_time = true; + } + + void set_read_timeout(const int& read_timeout) { + m_read_timeout = read_timeout; + is_set_read_timeout = true; + } + + void set_write_timeout(const int& write_timeout) { + m_write_timeout = write_timeout; + is_set_write_timeout = true; + } + + void set_auth_mode(const std::string& auth_mode) { + m_auth_mode = auth_mode; + is_set_auth_mode = true; + } + + void set_auth_region(const std::string& auth_region) { + m_auth_region = auth_region; + is_set_auth_region = true; + } + + void set_auth_host(const std::string& auth_host) { + m_auth_host = auth_host; + is_set_auth_host = true; + } + + void set_auth_port(const int& auth_port) { + m_auth_port = auth_port; + is_set_auth_port = true; + } + + void set_auth_expiration(const int& auth_expiration) { + m_auth_expiration = auth_expiration; + is_set_auth_expiration = true; + } + + void set_secret_id(const std::string& secret_id) { + m_secret_id = secret_id; + is_set_secret_id = true; + } + + void set_enable_custom_endpoint_monitoring(const bool& enable_custom_endpoint_monitoring) { + m_enable_custom_endpoint_monitoring = enable_custom_endpoint_monitoring; + is_set_enable_custom_endpoint_monitoring = true; + } + + void set_custom_endpoint_monitoring_region(const std::string& region) { + m_custom_endpoint_region = region; + is_set_custom_endpoint_region = true; + } + + void set_should_wait_for_info(const bool& wait_for_info) { + m_should_wait_for_info = wait_for_info; + is_set_should_wait_for_info = true; + } + + void set_custom_endpoint_info_refresh_rate_ms(const long& custom_endpoint_info_refresh_rate_ms) { + m_custom_endpoint_info_refresh_rate_ms = custom_endpoint_info_refresh_rate_ms; + is_set_custom_endpoint_info_refresh_rate_ms = true; + } + + void set_wait_on_cached_info_duration_ms(const long& wait_on_cached_info_duration_ms) { + m_wait_on_cached_info_duration_ms = wait_on_cached_info_duration_ms; + is_set_wait_on_cached_info_duration_ms = true; + } + + void set_idle_monitor_expiration_ms(const long& idle_monitor_expiration_ms) { + m_idle_monitor_expiration_ms = idle_monitor_expiration_ms; + is_set_idle_monitor_expiration_ms = true; + } +}; - std::string build() const { - if (connection_string->m_dsn.empty()) { - throw std::runtime_error("DSN is a required field in a connection string."); - } - if (connection_string->m_server.empty()) { - throw std::runtime_error("Server is a required field in a connection string."); - } - if (connection_string->m_port < 1) { - throw std::runtime_error("Port is a required field in a connection string."); - } - return connection_string->get_connection_string(); - } - - private: - std::unique_ptr connection_string; +class ConnectionStringBuilder { + public: + ConnectionStringBuilder() { connection_string.reset(new ConnectionString()); } + + ConnectionStringBuilder& withDSN(const std::string& dsn) { + connection_string->set_dsn(dsn); + return *this; + } + + ConnectionStringBuilder& withServer(const std::string& server) { + connection_string->set_server(server); + return *this; + } + + ConnectionStringBuilder& withPort(const int& port) { + connection_string->set_port(port); + return *this; + } + + ConnectionStringBuilder& withUID(const std::string& uid) { + connection_string->set_uid(uid); + return *this; + } + + ConnectionStringBuilder& withPWD(const std::string& pwd) { + connection_string->set_pwd(pwd); + return *this; + } + + ConnectionStringBuilder& withDatabase(const std::string& db) { + connection_string->set_db(db); + return *this; + } + + ConnectionStringBuilder& withLogQuery(const bool& log_query) { + connection_string->set_log_query(log_query); + return *this; + } + + ConnectionStringBuilder& withFailoverMode(const std::string& failover_mode) { + connection_string->set_failover_mode(failover_mode); + return *this; + } + + ConnectionStringBuilder& withMultiStatements(const bool& multi_statements) { + connection_string->set_multi_statements(multi_statements); + return *this; + } + + ConnectionStringBuilder& withEnableClusterFailover(const bool& enable_cluster_failover) { + connection_string->set_enable_cluster_failover(enable_cluster_failover); + return *this; + } + + ConnectionStringBuilder& withFailoverTimeout(const int& failover_t) { + connection_string->set_failover_timeout(failover_t); + return *this; + } + + ConnectionStringBuilder& withConnectTimeout(const int& connect_timeout) { + connection_string->set_connect_timeout(connect_timeout); + return *this; + } + + ConnectionStringBuilder& withNetworkTimeout(const int& network_timeout) { + connection_string->set_network_timeout(network_timeout); + return *this; + } + + ConnectionStringBuilder& withHostPattern(const std::string& host_pattern) { + connection_string->set_host_pattern(host_pattern); + return *this; + } + + ConnectionStringBuilder& withEnableFailureDetection(const bool& enable_failure_detection) { + connection_string->set_enable_failure_detection(enable_failure_detection); + return *this; + } + + ConnectionStringBuilder& withFailureDetectionTime(const int& failure_detection_time) { + connection_string->set_failure_detection_time(failure_detection_time); + return *this; + } + + ConnectionStringBuilder& withFailureDetectionTimeout(const int& failure_detection_timeout) { + connection_string->set_failure_detection_timeout(failure_detection_timeout); + return *this; + } + + ConnectionStringBuilder& withFailureDetectionInterval(const int& failure_detection_interval) { + connection_string->set_failure_detection_interval(failure_detection_interval); + return *this; + } + + ConnectionStringBuilder& withFailureDetectionCount(const int& failure_detection_count) { + connection_string->set_failure_detection_count(failure_detection_count); + return *this; + } + + ConnectionStringBuilder& withMonitorDisposalTime(const int& monitor_disposal_time) { + connection_string->set_monitor_disposal_time(monitor_disposal_time); + return *this; + } + + ConnectionStringBuilder& withReadTimeout(const int& read_timeout) { + connection_string->set_read_timeout(read_timeout); + return *this; + } + + ConnectionStringBuilder& withWriteTimeout(const int& write_timeout) { + connection_string->set_write_timeout(write_timeout); + return *this; + } + + ConnectionStringBuilder& withAuthMode(const std::string& auth_mode) { + connection_string->set_auth_mode(auth_mode); + return *this; + } + + ConnectionStringBuilder& withAuthRegion(const std::string& auth_region) { + connection_string->set_auth_region(auth_region); + return *this; + } + + ConnectionStringBuilder& withAuthHost(const std::string& auth_host) { + connection_string->set_auth_host(auth_host); + return *this; + } + + ConnectionStringBuilder& withAuthPort(const int& auth_port) { + connection_string->set_auth_port(auth_port); + return *this; + } + + ConnectionStringBuilder& withAuthExpiration(const int& auth_expiration) { + connection_string->set_auth_expiration(auth_expiration); + return *this; + } + + ConnectionStringBuilder& withSecretId(const std::string& secret_id) { + connection_string->set_secret_id(secret_id); + return *this; + } + + ConnectionStringBuilder& withEnableCustomEndpointMonitoring(const bool& enable_custom_endpoint_monitoring) { + connection_string->set_enable_custom_endpoint_monitoring(enable_custom_endpoint_monitoring); + return *this; + } + ConnectionStringBuilder& withCustomEndpointRegion(const std::string& region) { + connection_string->set_custom_endpoint_monitoring_region(region); + return *this; + } + + ConnectionStringBuilder& withShouldWaitForInfo(const bool& should_wait_for_info) { + connection_string->set_should_wait_for_info(should_wait_for_info); + return *this; + } + + ConnectionStringBuilder& withCustomEndpointInfoRefreshRateMs(const long& custom_endpoint_info_refresh_rate_ms) { + connection_string->set_custom_endpoint_info_refresh_rate_ms(custom_endpoint_info_refresh_rate_ms); + return *this; + } + + ConnectionStringBuilder& withWaitOnCachedInfoDurationMs(const long& wait_on_cached_info_duration_ms) { + connection_string->set_wait_on_cached_info_duration_ms(wait_on_cached_info_duration_ms); + return *this; + } + + ConnectionStringBuilder& withIdleMonitorExpirationMs(const long& idle_monitor_expiration_ms) { + connection_string->set_idle_monitor_expiration_ms(idle_monitor_expiration_ms); + return *this; + } + + std::string build() const { + if (connection_string->m_dsn.empty()) { + throw std::runtime_error("DSN is a required field in a connection string."); + } + if (connection_string->m_server.empty()) { + throw std::runtime_error("Server is a required field in a connection string."); + } + if (connection_string->m_port < 1) { + throw std::runtime_error("Port is a required field in a connection string."); + } + return connection_string->get_connection_string(); + } + + private: + std::unique_ptr connection_string; }; #endif /* __CONNECTIONSTRINGBUILDER_H__ */ diff --git a/integration/custom_endpoint_integration_test.cc b/integration/custom_endpoint_integration_test.cc new file mode 100644 index 000000000..5e687c094 --- /dev/null +++ b/integration/custom_endpoint_integration_test.cc @@ -0,0 +1,178 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#include "base_failover_integration_test.cc" +#include +#include +#include + +class CustomEndpointIntegrationTest : public BaseFailoverIntegrationTest { + protected: + std::string ACCESS_KEY = std::getenv("AWS_ACCESS_KEY_ID"); + std::string SECRET_ACCESS_KEY = std::getenv("AWS_SECRET_ACCESS_KEY"); + std::string SESSION_TOKEN = std::getenv("AWS_SESSION_TOKEN"); + std::string RDS_ENDPOINT = std::getenv("RDS_ENDPOINT"); + std::string RDS_REGION = std::getenv("RDS_REGION"); + std::string ENDPOINT_ID = + "test-endpoint-1-" + std::to_string(std::chrono::steady_clock::now().time_since_epoch().count()); + std::string region = "us-east-2"; + Aws::RDS::RDSClientConfiguration client_config; + Aws::RDS::RDSClient rds_client; + SQLHENV env = nullptr; + SQLHDBC dbc = nullptr; + + bool is_endpoint_created = false; + + static void SetUpTestSuite() { Aws::InitAPI(options); } + + static void TearDownTestSuite() { Aws::ShutdownAPI(options); } + void SetUp() override { + SQLAllocHandle(SQL_HANDLE_ENV, nullptr, &env); + SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, reinterpret_cast(SQL_OV_ODBC3), 0); + SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc); + + Aws::Auth::AWSCredentials credentials = + SESSION_TOKEN.empty() ? Aws::Auth::AWSCredentials(Aws::String(ACCESS_KEY), Aws::String(SECRET_ACCESS_KEY)) + : Aws::Auth::AWSCredentials(Aws::String(ACCESS_KEY), Aws::String(SECRET_ACCESS_KEY), + Aws::String(SESSION_TOKEN)); + if (!RDS_REGION.empty()) { + region = RDS_REGION; + } + client_config.region = region; + if (!RDS_ENDPOINT.empty()) { + client_config.endpointOverride = RDS_ENDPOINT; + } + rds_client = Aws::RDS::RDSClient(credentials, client_config); + + cluster_instances = retrieve_topology_via_SDK(rds_client, cluster_id); + writer_id = get_writer_id(cluster_instances); + writer_endpoint = get_endpoint(writer_id); + readers = get_readers(cluster_instances); + reader_id = get_first_reader_id(cluster_instances); + reader_endpoint = get_proxied_endpoint(reader_id); + target_writer_id = get_random_DB_cluster_reader_instance_id(readers); + + builder = ConnectionStringBuilder(); + builder.withPort(MYSQL_PORT).withLogQuery(true).withEnableFailureDetection(true); + + if (!is_endpoint_created) { + const std::vector writer{writer_id}; + create_custom_endpoint(cluster_id, writer); + wait_until_endpoint_available(cluster_id); + } + } + + void create_custom_endpoint(const std::string& cluster_id, const std::vector& writer) const { + Aws::RDS::Model::CreateDBClusterEndpointRequest rds_req; + rds_req.SetDBClusterEndpointIdentifier(ENDPOINT_ID); + rds_req.SetDBClusterIdentifier(cluster_id); + rds_req.SetEndpointType("ANY"); + rds_req.SetStaticMembers(writer); + rds_client.CreateDBClusterEndpoint(rds_req); + } + + void delete_custom_endpoint() { + Aws::RDS::Model::DeleteDBClusterEndpointRequest rds_req; + rds_req.SetDBClusterEndpointIdentifier(ENDPOINT_ID); + rds_client.DeleteDBClusterEndpoint(rds_req); + } + + /** + * Wait up to 5 minutes for the new custom endpoint to become unavailable. + */ + void wait_until_endpoint_available(const std::string& cluster_id) const { + const std::chrono::steady_clock::time_point end_time = std::chrono::steady_clock::now() + std::chrono::minutes(5); + bool is_available = false; + + Aws::String status = get_DB_cluster(rds_client, cluster_id).GetStatus(); + + while (std::chrono::steady_clock::now() < end_time) { + const auto endpoint_info = get_custom_endpoint_info(rds_client, ENDPOINT_ID); + is_available = endpoint_info.GetStatus() == "available"; + if (is_available) { + break; + } + std::this_thread::sleep_for(std::chrono::seconds(3)); + } + + if (!is_available) { + throw std::runtime_error( + "The test setup step timed out while waiting for the custom endpoint to become available: " + ENDPOINT_ID); + } + } + + void TearDown() override { + delete_custom_endpoint(); + + if (nullptr != dbc) { + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + } + if (nullptr != env) { + SQLFreeHandle(SQL_HANDLE_ENV, env); + } + } +}; + +TEST_F(CustomEndpointIntegrationTest, test_CustomeEndpointFailover) { + const auto endpoint_info = get_custom_endpoint_info(rds_client, ENDPOINT_ID); + + connection_string = builder.withDSN(dsn) + .withServer(endpoint_info.GetEndpoint()) + .withUID(user) + .withPWD(pwd) + .withDatabase(db) + .withFailoverMode("reader or writer") + .withEnableCustomEndpointMonitoring(true) + .withCustomEndpointRegion(region) + .build(); + SQLCHAR conn_out[4096] = "\0"; + SQLSMALLINT len; + EXPECT_EQ(SQL_SUCCESS, SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS, conn_out, + MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT)); + + const std::vector endpoint_members = endpoint_info.GetStaticMembers(); + for (const auto& member : endpoint_members) { + std::cout << "static members: " << member << std::endl; + } + const std::string current_connection_id = query_instance_id(dbc); + std::cout << "current connection id: " << current_connection_id << std::endl; + EXPECT_NE(std::find(endpoint_members.begin(), endpoint_members.end(), current_connection_id), endpoint_members.end()); + + failover_cluster_and_wait_until_writer_changed( + rds_client, cluster_id, writer_id, current_connection_id == writer_id ? target_writer_id : current_connection_id); + + assert_query_failed(dbc, SERVER_ID_QUERY, ERROR_COMM_LINK_CHANGED); + + const std::string new_connection_id = query_instance_id(dbc); + std::cout << "new connection id: " << new_connection_id << std::endl; + + EXPECT_NE(std::find(endpoint_members.begin(), endpoint_members.end(), new_connection_id), endpoint_members.end()); + + EXPECT_EQ(SQL_SUCCESS, SQLDisconnect(dbc)); +} diff --git a/unit_testing/CMakeLists.txt b/unit_testing/CMakeLists.txt index 1d37f23fe..c8184dc78 100644 --- a/unit_testing/CMakeLists.txt +++ b/unit_testing/CMakeLists.txt @@ -57,6 +57,8 @@ add_executable( adfs_proxy_test.cc cluster_aware_metrics_test.cc + custom_endpoint_monitor_test.cc + custom_endpoint_proxy_test.cc efm_proxy_test.cc iam_proxy_test.cc failover_handler_test.cc diff --git a/unit_testing/custom_endpoint_monitor_test.cc b/unit_testing/custom_endpoint_monitor_test.cc new file mode 100644 index 000000000..3ee3c94bb --- /dev/null +++ b/unit_testing/custom_endpoint_monitor_test.cc @@ -0,0 +1,109 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#include +#include + +#include + +#include +#include + +#include "driver/custom_endpoint_monitor.h" +#include "test_utils.h" +#include "mock_objects.h" + +using namespace Aws::RDS; + +using ::testing::_; +using ::testing::Return; +using ::testing::StrEq; + +namespace { +const std::string WRITER_CLUSTER_URL{"writer.cluster-XYZ.us-east-1.rds.amazonaws.com"}; +const std::string CUSTOM_ENDPOINT_URL{"custom.cluster-custom-XYZ.us-east-1.rds.amazonaws.com"}; +const auto ENDPOINT = Aws::Utils::Json::JsonValue(CUSTOM_ENDPOINT_URL); + +const long long REFRESH_RATE_NANOS = 50000000; +} // namespace + +static SQLHENV env; +static Aws::SDKOptions sdk_options; + +class CustomEndpointMonitorTest : public testing::Test { + protected: + DBC* dbc; + DataSource* ds; + MOCK_CONNECTION_PROXY* mock_connection_proxy; + std::shared_ptr mock_rds_client; + std::shared_ptr mock_topology_service; + + static void SetUpTestSuite() { + Aws::InitAPI(sdk_options); + } + + static void TearDownTestSuite() { + Aws::ShutdownAPI(sdk_options); + mysql_library_end(); + } + + void SetUp() override { + allocate_odbc_handles(env, dbc, ds); + mock_rds_client = std::make_shared(); + mock_connection_proxy = new MOCK_CONNECTION_PROXY(dbc, ds); + } + + void TearDown() override { + cleanup_odbc_handles(env, dbc, ds); + TEST_UTILS::get_custom_endpoint_cache().clear(); + delete mock_connection_proxy; + } +}; + +TEST_F(CustomEndpointMonitorTest, TestRun) { + Model::DBClusterEndpoint endpoint; + endpoint.AddStaticMembers(CUSTOM_ENDPOINT_URL); + std::vector endpoints{endpoint}; + + const auto expected_result = Model::DescribeDBClusterEndpointsResult().WithDBClusterEndpoints(endpoints); + const auto expected_outcome = Model::DescribeDBClusterEndpointsOutcome(expected_result); + + EXPECT_CALL(*mock_rds_client, DescribeDBClusterEndpoints(Property( + "GetDBClusterEndpointIdentifier", + &Model::DescribeDBClusterEndpointsRequest::GetDBClusterEndpointIdentifier, + StrEq("custom")))) + .WillRepeatedly(Return(expected_outcome)); + + CUSTOM_ENDPOINT_MONITOR monitor(mock_topology_service, CUSTOM_ENDPOINT_URL, "custom", "us-east-1", REFRESH_RATE_NANOS, + true, + mock_rds_client); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + monitor.stop(); +} diff --git a/unit_testing/custom_endpoint_proxy_test.cc b/unit_testing/custom_endpoint_proxy_test.cc new file mode 100644 index 000000000..41d7dc75b --- /dev/null +++ b/unit_testing/custom_endpoint_proxy_test.cc @@ -0,0 +1,104 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License, version 2.0 +// (GPLv2), as published by the Free Software Foundation, with the +// following additional permissions: +// +// This program is distributed with certain software that is licensed +// under separate terms, as designated in a particular file or component +// or in the license documentation. Without limiting your rights under +// the GPLv2, the authors of this program hereby grant you an additional +// permission to link the program and your derivative works with the +// separately licensed software that they have included with the program. +// +// Without limiting the foregoing grant of rights under the GPLv2 and +// additional permission as to separately licensed software, this +// program is also subject to the Universal FOSS Exception, version 1.0, +// a copy of which can be found along with its FAQ at +// http://oss.oracle.com/licenses/universal-foss-exception. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +// See the GNU General Public License, version 2.0, for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see +// http://www.gnu.org/licenses/gpl-2.0.html. + +#include +#include +#include + +#include "test_utils.h" +#include "mock_objects.h" + +using ::testing::_; +using ::testing::Return; +using ::testing::StrEq; + +namespace { +const std::string WRITER_CLUSTER_URL{"writer.cluster-XYZ.us-east-1.rds.amazonaws.com"}; +const std::string CUSTOM_ENDPOINT_URL{"custom.cluster-custom-XYZ.us-east-1.rds.amazonaws.com"}; +} // namespace + +static SQLHENV env; +static Aws::SDKOptions options; + +class CustomEndpointProxyTest : public testing::Test { + protected: + DBC* dbc; + DataSource* ds; + MOCK_CONNECTION_PROXY* mock_connection_proxy; + std::shared_ptr mock_monitor = std::make_shared(); + + static void SetUpTestSuite() {} + + static void TearDownTestSuite() { mysql_library_end(); } + + void SetUp() override { + allocate_odbc_handles(env, dbc, ds); + ds->opt_ENABLE_CUSTOM_ENDPOINT_MONITORING = true; + ds->opt_WAIT_FOR_CUSTOM_ENDPOINT_INFO = true; + ds->opt_CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS = 60000; + + mock_connection_proxy = new MOCK_CONNECTION_PROXY(dbc, ds); + } + + void TearDown() override { + TEST_UTILS::get_custom_endpoint_monitor_cache().release_resources(); + cleanup_odbc_handles(env, dbc, ds); + } +}; + +TEST_F(CustomEndpointProxyTest, TestConnect_MonitorNotCreatedIfNotCustomEndpointHost) { + TEST_CUSTOM_ENDPOINT_PROXY custom_endpoint_proxy(dbc, ds, mock_connection_proxy); + EXPECT_EQ(TEST_UTILS::get_custom_endpoint_monitor_cache().size(), 0); + custom_endpoint_proxy.connect(WRITER_CLUSTER_URL.c_str(), "", "", "", 3306, "", 0); + + EXPECT_EQ(TEST_UTILS::get_custom_endpoint_monitor_cache().size(), 0); + EXPECT_CALL(*mock_connection_proxy, connect(_, _, _, _, _, _, _)).Times(0); +} + +TEST_F(CustomEndpointProxyTest, TestConnect_MonitorCreated) { + TEST_CUSTOM_ENDPOINT_PROXY custom_endpoint_proxy(dbc, ds, mock_connection_proxy); + EXPECT_EQ(0, TEST_UTILS::get_custom_endpoint_monitor_cache().size()); + EXPECT_CALL(custom_endpoint_proxy, create_custom_endpoint_monitor(_)).WillOnce(Return(mock_monitor)); + EXPECT_CALL(*mock_connection_proxy, connect(_, _, _, _, _, _, _)).Times(1); + custom_endpoint_proxy.connect(CUSTOM_ENDPOINT_URL.c_str(), "", "", "", 3306, "", 0); + + EXPECT_EQ(TEST_UTILS::get_custom_endpoint_monitor_cache().size(), 1); +} + +TEST_F(CustomEndpointProxyTest, TestConnect_TimeoutWaitingForInfo) { + TEST_CUSTOM_ENDPOINT_PROXY custom_endpoint_proxy(dbc, ds, mock_connection_proxy); + ds->opt_WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS = 100; + EXPECT_EQ(TEST_UTILS::get_custom_endpoint_monitor_cache().size(), 0); + EXPECT_CALL(custom_endpoint_proxy, create_custom_endpoint_monitor(_)).WillOnce(Return(mock_monitor)); + ds->opt_WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS = 1; + custom_endpoint_proxy.connect(CUSTOM_ENDPOINT_URL.c_str(), "user", "pwd", "db", 3306, "", 0); + + EXPECT_EQ(TEST_UTILS::get_custom_endpoint_monitor_cache().size(), 1); + EXPECT_CALL(*mock_connection_proxy, connect(_, _, _, _, _, _, _)).Times(0); +} diff --git a/unit_testing/failover_handler_test.cc b/unit_testing/failover_handler_test.cc index f86f1ce2b..62d033a41 100644 --- a/unit_testing/failover_handler_test.cc +++ b/unit_testing/failover_handler_test.cc @@ -43,6 +43,7 @@ using ::testing::StrEq; namespace { const std::string US_EAST_REGION_CLUSTER = "database-test-name.cluster-XYZ.us-east-2.rds.amazonaws.com"; const std::string US_EAST_REGION_CLUSTER_READ_ONLY = "database-test-name.cluster-ro-XYZ.us-east-2.rds.amazonaws.com"; + const std::string US_EAST_REGION_INSTANCE = "instance-test-name.XYZ.us-east-2.rds.amazonaws.com"; const std::string US_EAST_REGION_PROXY = "proxy-test-name.proxy-XYZ.us-east-2.rds.amazonaws.com"; const std::string US_EAST_REGION_CUSTON_DOMAIN = "custom-test-name.cluster-custom-XYZ.us-east-2.rds.amazonaws.com"; @@ -424,6 +425,26 @@ TEST_F(FailoverHandlerTest, GetRdsClusterHostUrl) { EXPECT_EQ(std::string(), TEST_UTILS::get_rds_cluster_host_url(CHINA_REGION_CUSTON_DOMAIN)); } +TEST_F(FailoverHandlerTest, GetRdsClusterId) { + EXPECT_EQ("database-test-name", TEST_UTILS::get_rds_cluster_id(US_EAST_REGION_CLUSTER)); + EXPECT_EQ("database-test-name", TEST_UTILS::get_rds_cluster_id(US_EAST_REGION_CLUSTER_READ_ONLY)); + EXPECT_EQ(std::string(), TEST_UTILS::get_rds_cluster_id(US_EAST_REGION_INSTANCE)); + + EXPECT_EQ("proxy-test-name", TEST_UTILS::get_rds_cluster_id(US_EAST_REGION_PROXY)); + EXPECT_EQ("custom-test-name", TEST_UTILS::get_rds_cluster_id(US_EAST_REGION_CUSTON_DOMAIN)); + + EXPECT_EQ("database-test-name", TEST_UTILS::get_rds_cluster_id(CHINA_REGION_CLUSTER)); + EXPECT_EQ("database-test-name", TEST_UTILS::get_rds_cluster_id(CHINA_REGION_CLUSTER_READ_ONLY)); + EXPECT_EQ("proxy-test-name", TEST_UTILS::get_rds_cluster_id(CHINA_REGION_PROXY)); + EXPECT_EQ("custom-test-name", TEST_UTILS::get_rds_cluster_id(CHINA_REGION_CUSTON_DOMAIN)); +} + +TEST_F(FailoverHandlerTest, GetRdsInstanceId) { + EXPECT_EQ("database-test-name", TEST_UTILS::get_rds_instance_id(US_EAST_REGION_INSTANCE)); + EXPECT_EQ(std::string(), TEST_UTILS::get_rds_instance_id(US_EAST_REGION_CLUSTER_READ_ONLY)); + EXPECT_EQ(std::string(), TEST_UTILS::get_rds_instance_id(US_EAST_REGION_CLUSTER)); +} + TEST_F(FailoverHandlerTest, ConnectToNewWriter) { std::string server = "my-cluster-name.cluster-XYZ.us-east-2.rds.amazonaws.com"; ds->opt_SERVER.set_remove_brackets((SQLWCHAR*)to_sqlwchar_string(server).c_str(), server.size()); diff --git a/unit_testing/mock_objects.h b/unit_testing/mock_objects.h index 80ebce772..f684ec3d9 100644 --- a/unit_testing/mock_objects.h +++ b/unit_testing/mock_objects.h @@ -35,6 +35,7 @@ #include #include "driver/connection_proxy.h" +#include "driver/custom_endpoint_proxy.h" #include "driver/failover.h" #include "driver/saml_http_client.h" #include "driver/monitor_thread_container.h" @@ -222,6 +223,13 @@ class MOCK_SECRETS_MANAGER_CLIENT : public Aws::SecretsManager::SecretsManagerCl MOCK_METHOD(Aws::SecretsManager::Model::GetSecretValueOutcome, GetSecretValue, (const Aws::SecretsManager::Model::GetSecretValueRequest&), (const)); }; +class MOCK_RDS_CLIENT : public Aws::RDS::RDSClient { +public: + MOCK_RDS_CLIENT() : RDSClient(){}; + + MOCK_METHOD(Aws::RDS::Model::DescribeDBClusterEndpointsOutcome, DescribeDBClusterEndpoints, (const Aws::RDS::Model::DescribeDBClusterEndpointsRequest&), (const)); +}; + class MOCK_AUTH_UTIL : public AUTH_UTIL { public: MOCK_AUTH_UTIL() : AUTH_UTIL() {}; @@ -234,4 +242,16 @@ class MOCK_SAML_HTTP_CLIENT : public SAML_HTTP_CLIENT { MOCK_METHOD(nlohmann::json, post, (const std::string&, const std::string&, const std::string&)); MOCK_METHOD(nlohmann::json, get, (const std::string&, const httplib::Headers&)); }; + +class TEST_CUSTOM_ENDPOINT_PROXY : public CUSTOM_ENDPOINT_PROXY { +public: + TEST_CUSTOM_ENDPOINT_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy) : CUSTOM_ENDPOINT_PROXY(dbc, ds, next_proxy) {}; + MOCK_METHOD(std::shared_ptr, create_custom_endpoint_monitor, (const long long refresh_rate_nanos), (override)); + static int get_monitor_size() { return monitors.size(); } +}; + +class MOCK_CUSTOM_ENDPOINT_MONITOR : public CUSTOM_ENDPOINT_MONITOR { + public: + MOCK_CUSTOM_ENDPOINT_MONITOR() {}; +}; #endif /* __MOCKOBJECTS_H__ */ diff --git a/unit_testing/sliding_expiration_cache_test.cc b/unit_testing/sliding_expiration_cache_test.cc index b623c4b8a..e9364a613 100644 --- a/unit_testing/sliding_expiration_cache_test.cc +++ b/unit_testing/sliding_expiration_cache_test.cc @@ -145,7 +145,7 @@ TEST_F(SlidingExpirationCacheTest, ExpirationTimeUpdateGet) { TEST_F(SlidingExpirationCacheTest, GetCacheExpireThread) { SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD cache(cache_exp_short); - + cache.init_clean_up_thread(); EXPECT_EQ(0, cache.size()); cache.put(cache_key_a, cache_val_a, cache_exp_short); cache.put(cache_key_b, cache_val_b, cache_exp_long); diff --git a/unit_testing/test_utils.cc b/unit_testing/test_utils.cc index 3b96eb9bf..7a01cb977 100644 --- a/unit_testing/test_utils.cc +++ b/unit_testing/test_utils.cc @@ -29,6 +29,9 @@ #include "test_utils.h" +#include "driver/custom_endpoint_monitor.h" +#include "driver/custom_endpoint_proxy.h" + void allocate_odbc_handles(SQLHENV& env, DBC*& dbc, DataSource*& ds) { SQLHDBC hdbc = nullptr; @@ -155,6 +158,23 @@ std::string TEST_UTILS::get_rds_cluster_host_url(std::string host) { return RDS_UTILS::get_rds_cluster_host_url(host); } +std::string TEST_UTILS::get_rds_cluster_id(std::string host) { + return RDS_UTILS::get_rds_cluster_id(host); +} + +std::string TEST_UTILS::get_rds_instance_id(std::string host) { + return RDS_UTILS::get_rds_instance_id(host); +} + std::string TEST_UTILS::get_rds_instance_host_pattern(std::string host) { return RDS_UTILS::get_rds_instance_host_pattern(host); } + +CACHE_MAP>& TEST_UTILS::get_custom_endpoint_cache() { + return std::ref(CUSTOM_ENDPOINT_MONITOR::custom_endpoint_cache); +} + +SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD>& +TEST_UTILS::get_custom_endpoint_monitor_cache() { + return std::ref(CUSTOM_ENDPOINT_PROXY::monitors); +} diff --git a/unit_testing/test_utils.h b/unit_testing/test_utils.h index d7ce4f84f..f431ce8fc 100644 --- a/unit_testing/test_utils.h +++ b/unit_testing/test_utils.h @@ -30,46 +30,54 @@ #ifndef __TESTUTILS_H__ #define __TESTUTILS_H__ -#include "driver/auth_util.h" +#include "driver/auth_util.h" +#include "driver/cache_map.h" +#include "driver/custom_endpoint_info.h" +#include "driver/custom_endpoint_monitor.h" #include "driver/driver.h" #include "driver/failover.h" #include "driver/iam_proxy.h" -#include "driver/okta_proxy.h" #include "driver/monitor.h" #include "driver/monitor_thread_container.h" -#include "driver/secrets_manager_proxy.h" +#include "driver/okta_proxy.h" #include "driver/rds_utils.h" +#include "driver/secrets_manager_proxy.h" +#include "driver/sliding_expiration_cache_with_clean_up_thread.h" void allocate_odbc_handles(SQLHENV& env, DBC*& dbc, DataSource*& ds); void cleanup_odbc_handles(SQLHENV env, DBC*& dbc, DataSource*& ds, bool call_myodbc_end = false); class TEST_UTILS { -public: - static std::chrono::milliseconds get_connection_check_interval(std::shared_ptr monitor); - static CONNECTION_STATUS check_connection_status(std::shared_ptr monitor); - static void populate_monitor_map(std::shared_ptr container, - std::set node_keys, std::shared_ptr monitor); - static void populate_task_map(std::shared_ptr container, - std::shared_ptr monitor); - static bool has_monitor(std::shared_ptr container, std::string node_key); - static bool has_task(std::shared_ptr container, std::shared_ptr monitor); - static bool has_any_tasks(std::shared_ptr container); - static bool has_available_monitor(std::shared_ptr container); - static std::shared_ptr get_available_monitor(std::shared_ptr container); - static size_t get_map_size(std::shared_ptr container); - static std::list> get_contexts(std::shared_ptr monitor); - static std::string build_cache_key(const char* host, const char* region, unsigned int port, const char* user); - static bool token_cache_contains_key(std::unordered_map token_cache, std::string cache_key); - static std::map, Aws::Utils::Json::JsonValue>& get_secrets_cache(); - static bool try_parse_region_from_secret(std::string secret, std::string& region); - static bool is_dns_pattern_valid(std::string host); - static bool is_rds_dns(std::string host); - static bool is_rds_cluster_dns(std::string host); - static bool is_rds_proxy_dns(std::string host); - static bool is_rds_writer_cluster_dns(std::string host); - static bool is_rds_custom_cluster_dns(std::string host); - static std::string get_rds_cluster_host_url(std::string host); - static std::string get_rds_instance_host_pattern(std::string host); + public: + static std::chrono::milliseconds get_connection_check_interval(std::shared_ptr monitor); + static CONNECTION_STATUS check_connection_status(std::shared_ptr monitor); + static void populate_monitor_map(std::shared_ptr container, std::set node_keys, + std::shared_ptr monitor); + static void populate_task_map(std::shared_ptr container, std::shared_ptr monitor); + static bool has_monitor(std::shared_ptr container, std::string node_key); + static bool has_task(std::shared_ptr container, std::shared_ptr monitor); + static bool has_any_tasks(std::shared_ptr container); + static bool has_available_monitor(std::shared_ptr container); + static std::shared_ptr get_available_monitor(std::shared_ptr container); + static size_t get_map_size(std::shared_ptr container); + static std::list> get_contexts(std::shared_ptr monitor); + static std::string build_cache_key(const char* host, const char* region, unsigned int port, const char* user); + static bool token_cache_contains_key(std::unordered_map token_cache, std::string cache_key); + static std::map, Aws::Utils::Json::JsonValue>& get_secrets_cache(); + static bool try_parse_region_from_secret(std::string secret, std::string& region); + static bool is_dns_pattern_valid(std::string host); + static bool is_rds_dns(std::string host); + static bool is_rds_cluster_dns(std::string host); + static bool is_rds_proxy_dns(std::string host); + static bool is_rds_writer_cluster_dns(std::string host); + static bool is_rds_custom_cluster_dns(std::string host); + static std::string get_rds_cluster_host_url(std::string host); + static std::string get_rds_cluster_id(std::string host); + static std::string get_rds_instance_id(std::string host); + static std::string get_rds_instance_host_pattern(std::string host); + static CACHE_MAP>& get_custom_endpoint_cache(); + static SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD>& + get_custom_endpoint_monitor_cache(); }; #endif /* __TESTUTILS_H__ */