test: custom endpoints unit tests and integration tests
karenc-bq committed Jan 23, 2025
1 parent 07f6b7d commit 290d4d7
Showing 39 changed files with 1,424 additions and 697 deletions.
1 change: 1 addition & 0 deletions .github/workflows/failover.yml
@@ -1,6 +1,7 @@
name: Failover Unit Tests

on:
workflow_dispatch:
push:
branches:
- main
24 changes: 24 additions & 0 deletions docs/using-the-aws-driver/CustomEndpoint.md
@@ -0,0 +1,24 @@
# Custom Endpoint Support

Custom endpoint support allows client applications to use the driver with RDS custom endpoints. When the custom endpoint feature is enabled, the driver analyses custom endpoint information to ensure that the instances used in connections are part of the custom endpoint in use. This also applies to connections used in failover.

## How to use the Driver with Custom Endpoint

### Enabling the Custom Endpoint Feature

1. If you do not already have a custom endpoint, create one using the AWS RDS Console:
- For details, see the documentation about [creating a custom endpoint](https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/aurora-custom-endpoint-creating.html).
2. Set `ENABLE_CUSTOM_ENDPOINT_MONITORING` to `TRUE` to enable custom endpoint support.
3. If you are using the failover plugin, set the failover parameter `FAILOVER_MODE` according to the custom endpoint type. For example, if the custom endpoint you are using is of type `READER`, you can set `FAILOVER_MODE` to `strict-reader`, or if it is of type `ANY`, you can set `FAILOVER_MODE` to `reader-or-writer`.
4. Specify parameters that are required or specific to your case.
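
The steps above can be sketched as a small helper that assembles a connection string. This is only an illustration: the `DSN` and `SERVER` keys, the DSN name, and the endpoint host name are assumptions made for the example, and `build_connection_string` is a hypothetical helper, not part of the driver. The custom endpoint and failover parameter names come from this document.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical helper: joins key/value pairs into a "KEY=value;" connection string.
std::string build_connection_string(const std::vector<std::pair<std::string, std::string>>& params) {
  std::string result;
  for (const auto& param : params) {
    assert(!param.first.empty());  // every parameter needs a name
    result += param.first + "=" + param.second + ";";
  }
  return result;
}

// Example for a custom endpoint of type READER used together with the failover plugin.
std::string custom_endpoint_connection_string() {
  return build_connection_string({
      {"DSN", "my-dsn"},  // assumed DSN name
      // Assumed host name, for illustration only.
      {"SERVER", "my-endpoint.cluster-custom-example.us-west-1.rds.amazonaws.com"},
      {"ENABLE_CUSTOM_ENDPOINT_MONITORING", "TRUE"},
      {"FAILOVER_MODE", "strict-reader"},
      {"CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS", "20000"},
  });
}
```

For a custom endpoint of type `ANY`, the same sketch would set `FAILOVER_MODE` to `reader-or-writer` instead.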

### Custom Endpoint Plugin Parameters

| Parameter | Value | Required | Description | Default Value | Example Value |
| ------------------------------------------ | :----: | :------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------- | ------------- |
| `ENABLE_CUSTOM_ENDPOINT_MONITORING` | bool | No | Set to TRUE to enable custom endpoint support. | `FALSE` | `TRUE` |
| `CUSTOM_ENDPOINT_REGION` | string | No | The region of the cluster's custom endpoints. If not specified, the region will be parsed from the URL. | `N/A` | `us-west-1` |
| `CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS` | long | No | Controls how frequently custom endpoint monitors fetch custom endpoint info, in milliseconds. | `30000` | `20000` |
| `CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS` | long | No | Controls how long a monitor should run without use before expiring and being removed, in milliseconds. | `900000` (15 minutes) | `600000` |
| `WAIT_FOR_CUSTOM_ENDPOINT_INFO` | bool | No | Controls whether to wait for custom endpoint info to become available before connecting or executing a method. Waiting is only necessary if a connection to a given custom endpoint has not been opened or used recently. Note that disabling this may result in occasional connections to instances outside of the custom endpoint. | `true` | `true` |
| `WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS` | long | No | Controls the maximum amount of time that the plugin will wait for custom endpoint info to be made available by the custom endpoint monitor, in milliseconds. | `5000` | `7000` |
4 changes: 2 additions & 2 deletions driver/CMakeLists.txt
@@ -73,9 +73,9 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT})
connect.cc
connection_handler.cc
connection_proxy.cc
custom_endpoint_proxy.cc
custom_endpoint_info.cc
custom_endpoint_monitor.cc
custom_endpoint_proxy.cc
cursor.cc
desc.cc
dll.cc
@@ -148,9 +148,9 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT})
cluster_topology_info.h
connection_handler.h
connection_proxy.h
custom_endpoint_proxy.h
custom_endpoint_info.h
custom_endpoint_monitor.h
custom_endpoint_proxy.h
driver.h
efm_proxy.h
error.h
4 changes: 2 additions & 2 deletions driver/cache_map.h
@@ -27,8 +27,8 @@
// along with this program. If not, see
// http://www.gnu.org/licenses/gpl-2.0.html.

#ifndef __CACHE_MAP__
#define __CACHE_MAP__
#ifndef __CACHE_MAP_H__
#define __CACHE_MAP_H__

#include <atomic>
#include <chrono>
22 changes: 22 additions & 0 deletions driver/cluster_topology_info.cc
@@ -30,6 +30,7 @@
#include "cluster_topology_info.h"

#include <stdexcept>
#include <algorithm>

/**
Initialize and return random number.
@@ -75,6 +76,20 @@ void CLUSTER_TOPOLOGY_INFO::add_host(std::shared_ptr<HOST_INFO> host_info) {
update_time();
}

void CLUSTER_TOPOLOGY_INFO::remove_host(std::shared_ptr<HOST_INFO> host_info) {
auto position = std::find(writers.begin(), writers.end(), host_info);
if (position != writers.end()) {
writers.erase(position);
}

position = std::find(readers.begin(), readers.end(), host_info);
if (position != readers.end()) {
readers.erase(position);
}
update_time();
}


size_t CLUSTER_TOPOLOGY_INFO::total_hosts() {
return writers.size() + readers.size();
}
@@ -136,6 +151,13 @@ std::vector<std::shared_ptr<HOST_INFO>> CLUSTER_TOPOLOGY_INFO::get_writers() {
return writers;
}

std::vector<std::shared_ptr<HOST_INFO>> CLUSTER_TOPOLOGY_INFO::get_instances() {
std::vector instances(writers);
instances.insert(instances.end(), readers.begin(), readers.end());

return instances;
}

std::shared_ptr<HOST_INFO> CLUSTER_TOPOLOGY_INFO::get_last_used_reader() {
return last_used_reader;
}
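
The behaviour of `remove_host` and `get_instances` can be sketched in isolation as follows. This is a simplified stand-in, not the driver's class: `HostInfo` and `TopologySketch` are hypothetical names, and the sketch assumes that `get_instances` is meant to return all writers followed by all readers.

```cpp
#include <algorithm>
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for HOST_INFO; only the host name is kept.
struct HostInfo {
  std::string host;
};

// Hypothetical stand-in for CLUSTER_TOPOLOGY_INFO.
class TopologySketch {
 public:
  using HostList = std::vector<std::shared_ptr<HostInfo>>;

  void add_host(std::shared_ptr<HostInfo> host, bool is_writer) {
    assert(host != nullptr);
    (is_writer ? writers : readers).push_back(std::move(host));
  }

  // Removes the host from both lists. std::find compares shared_ptr values,
  // which is pointer identity, so the same object that was added must be passed in.
  void remove_host(std::shared_ptr<HostInfo> host) {
    erase_from(writers, host);
    erase_from(readers, host);
  }

  // Returns a copy containing the writers followed by the readers.
  HostList get_instances() const {
    HostList instances(writers);
    instances.insert(instances.end(), readers.begin(), readers.end());
    return instances;
  }

 private:
  static void erase_from(HostList& hosts, const std::shared_ptr<HostInfo>& host) {
    auto position = std::find(hosts.begin(), hosts.end(), host);
    if (position != hosts.end()) {
      hosts.erase(position);
    }
  }

  HostList writers;
  HostList readers;
};
```

Because the lookup is by pointer identity, a second `HostInfo` object with the same host name would not match. A lookup by host name would need `std::find_if` with a predicate instead.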
2 changes: 2 additions & 0 deletions driver/cluster_topology_info.h
@@ -46,6 +46,7 @@ class CLUSTER_TOPOLOGY_INFO {
virtual ~CLUSTER_TOPOLOGY_INFO();

void add_host(std::shared_ptr<HOST_INFO> host_info);
void remove_host(std::shared_ptr<HOST_INFO> host_info);
size_t total_hosts();
size_t num_readers(); // return number of readers in the cluster
std::time_t time_last_updated();
@@ -58,6 +59,7 @@ class CLUSTER_TOPOLOGY_INFO {
std::shared_ptr<HOST_INFO> get_reader(int i);
std::vector<std::shared_ptr<HOST_INFO>> get_writers();
std::vector<std::shared_ptr<HOST_INFO>> get_readers();
std::vector<std::shared_ptr<HOST_INFO>> get_instances();

private:
int current_reader = -1;
5 changes: 1 addition & 4 deletions driver/custom_endpoint_info.h
@@ -33,10 +33,7 @@
#include <aws/rds/model/DBClusterEndpoint.h>

#include <set>
#include <sstream>
#include <utility>

#include "MYODBC_MYSQL.h"
#include "stringutil.h"
#include "mylog.h"

/**
182 changes: 102 additions & 80 deletions driver/custom_endpoint_monitor.cc
@@ -27,18 +27,17 @@
// along with this program. If not, see
// http://www.gnu.org/licenses/gpl-2.0.html.

#include "custom_endpoint_monitor.h"

#include <aws/core/auth/AWSCredentialsProviderChain.h>
#include <aws/rds/model/DBClusterEndpoint.h>
#include <aws/rds/model/DescribeDBClusterEndpointsRequest.h>
#include <aws/rds/model/Filter.h>
#include <utility>
#include <vector>

#include "allowed_and_blocked_hosts.h"
#include "aws_sdk_helper.h"
#include "custom_endpoint_monitor.h"
#include "driver.h"
#include "monitor_service.h"
#include "mylog.h"

namespace {
@@ -47,109 +46,132 @@ AWS_SDK_HELPER SDK_HELPER;

CACHE_MAP<std::string, std::shared_ptr<CUSTOM_ENDPOINT_INFO>> CUSTOM_ENDPOINT_MONITOR::custom_endpoint_cache;

CUSTOM_ENDPOINT_MONITOR::CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr<HOST_INFO>& custom_endpoint_host_info,
CUSTOM_ENDPOINT_MONITOR::CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr<TOPOLOGY_SERVICE> topology_service,
const std::string& custom_endpoint_host,
const std::string& endpoint_identifier, const std::string& region,
DataSource* ds, bool enable_logging)
: custom_endpoint_host_info(custom_endpoint_host_info),
endpoint_identifier(endpoint_identifier),
region(region),
enable_logging(enable_logging) {
long long refresh_rate_nanos, bool enable_logging)
: topology_service(topology_service),
custom_endpoint_host(custom_endpoint_host),
endpoint_identifier(endpoint_identifier),
region(region),
refresh_rate_nanos(refresh_rate_nanos),
enable_logging(enable_logging) {
if (enable_logging) {
this->logger = init_log_file();
}

++SDK_HELPER;
this->run();
}

Aws::RDS::RDSClientConfiguration client_config;
if (!region.empty()) {
client_config.region = region;
#ifdef UNIT_TEST_BUILD
CUSTOM_ENDPOINT_MONITOR::CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr<TOPOLOGY_SERVICE> topology_service,
const std::string& custom_endpoint_host,
const std::string& endpoint_identifier, const std::string& region,
long long refresh_rate_nanos, bool enable_logging,
std::shared_ptr<Aws::RDS::RDSClient> client)
: topology_service(topology_service),
custom_endpoint_host(custom_endpoint_host),
endpoint_identifier(endpoint_identifier),
region(region),
refresh_rate_nanos(refresh_rate_nanos),
enable_logging(enable_logging) {
if (enable_logging) {
this->logger = init_log_file();
}

this->rds_client = std::make_shared<Aws::RDS::RDSClient>(
Aws::Auth::DefaultAWSCredentialsProviderChain().GetAWSCredentials(), client_config);

this->run();
thread_pool = std::thread(&CUSTOM_ENDPOINT_MONITOR::run, this);
}
#endif

bool CUSTOM_ENDPOINT_MONITOR::should_dispose() { return true; }

bool CUSTOM_ENDPOINT_MONITOR::has_custom_endpoint_info() const {
auto default_val = std::shared_ptr<CUSTOM_ENDPOINT_INFO>(nullptr);
return custom_endpoint_cache.get(this->custom_endpoint_host_info->get_host(), default_val) != default_val;
return custom_endpoint_cache.get(this->custom_endpoint_host, default_val) != default_val;
}

void CUSTOM_ENDPOINT_MONITOR::run() {
this->thread_pool.resize(1);
this->thread_pool.push([=](int id) {
MYLOG_TRACE(this->logger, 0, "Starting custom endpoint monitor for '%s'",
this->custom_endpoint_host_info->get_host().c_str());

try {
while (!this->should_stop.load()) {
const std::chrono::time_point start = std::chrono::steady_clock::now();
Aws::RDS::Model::Filter filter;
filter.SetName("db-cluster-endpoint-type");
filter.SetValues({"custom"});

Aws::RDS::Model::DescribeDBClusterEndpointsRequest request;
request.SetDBClusterIdentifier(this->endpoint_identifier);
request.SetFilters({filter});
const auto response = this->rds_client->DescribeDBClusterEndpoints(request);

const auto custom_endpoints = response.GetResult().GetDBClusterEndpoints();
if (custom_endpoints.size() != 1) {
MYLOG_TRACE(this->logger, 0,
"Unexpected number of custom endpoints with endpoint identifier '%s' in region '%s'. Expected 1 "
"custom endpoint, but found %d. Endpoints: %s",
endpoint_identifier.c_str(), region.c_str(), custom_endpoints.size(),
this->get_endpoints_as_string(custom_endpoints).c_str());

std::this_thread::sleep_for(std::chrono::nanoseconds(this->refresh_rate_nanos));
continue;
}
const std::shared_ptr<CUSTOM_ENDPOINT_INFO> endpoint_info =
++SDK_HELPER;

Aws::RDS::RDSClientConfiguration client_config;
if (!region.empty()) {
client_config.region = region;
}

const Aws::RDS::RDSClient rds_client(Aws::Auth::DefaultAWSCredentialsProviderChain().GetAWSCredentials(),
client_config);
Aws::RDS::Model::Filter filter;
filter.SetName("db-cluster-endpoint-type");
filter.AddValues("custom");

Aws::RDS::Model::DescribeDBClusterEndpointsRequest request;
request.SetDBClusterEndpointIdentifier(this->endpoint_identifier);
// TODO: Investigate why filters returns `InvalidParameterCombination` error saying filter values are null.
// request.AddFilters(filter);

MYLOG_TRACE(this->logger, 0, "Starting custom endpoint monitor for '%s'", this->custom_endpoint_host.c_str());

try {
while (!should_stop) {
const std::chrono::time_point start = std::chrono::steady_clock::now();
const auto response = rds_client.DescribeDBClusterEndpoints(request);
const auto custom_endpoints = response.GetResult().GetDBClusterEndpoints();
if (custom_endpoints.size() != 1) {
MYLOG_TRACE(this->logger, 0,
"Unexpected number of custom endpoints with endpoint identifier '%s' in region '%s'. Expected 1 "
"custom endpoint, but found %d. Endpoints: %s",
endpoint_identifier.c_str(), region.c_str(), custom_endpoints.size(),
this->get_endpoints_as_string(custom_endpoints).c_str());

std::this_thread::sleep_for(std::chrono::nanoseconds(this->refresh_rate_nanos));
continue;
}
const std::shared_ptr<CUSTOM_ENDPOINT_INFO> endpoint_info =
CUSTOM_ENDPOINT_INFO::from_db_cluster_endpoint(custom_endpoints[0]);
const std::shared_ptr<CUSTOM_ENDPOINT_INFO> cache_endpoint_info =
custom_endpoint_cache.get(this->custom_endpoint_host_info->get_host(), nullptr);
const std::shared_ptr<CUSTOM_ENDPOINT_INFO> cache_endpoint_info =
custom_endpoint_cache.get(this->custom_endpoint_host, nullptr);

if (cache_endpoint_info != nullptr && cache_endpoint_info == endpoint_info) {
const long long elapsed_time =
if (cache_endpoint_info != nullptr && cache_endpoint_info == endpoint_info) {
const long long elapsed_time =
std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::steady_clock::now() - start).count();
std::this_thread::sleep_for(
std::this_thread::sleep_for(
std::chrono::nanoseconds(std::max(static_cast<long long>(0), this->refresh_rate_nanos - elapsed_time)));
continue;
}
continue;
}

MYLOG_TRACE(this->logger, 0, "Detected change in custom endpoint info for '%s':\n{%s}",
custom_endpoint_host_info->get_host().c_str(), endpoint_info->to_string().c_str());
MYLOG_TRACE(this->logger, 0, "Detected change in custom endpoint info for '%s':\n{%s}",
custom_endpoint_host.c_str(), endpoint_info->to_string().c_str());

// The custom endpoint info has changed, so we need to update the set of allowed/blocked hosts.
std::shared_ptr<ALLOWED_AND_BLOCKED_HOSTS> allowed_and_blocked_hosts;
if (endpoint_info->get_member_list_type() == STATIC_LIST) {
allowed_and_blocked_hosts =
// The custom endpoint info has changed, so we need to update the set of allowed/blocked hosts.
std::shared_ptr<ALLOWED_AND_BLOCKED_HOSTS> allowed_and_blocked_hosts;
if (endpoint_info->get_member_list_type() == STATIC_LIST) {
allowed_and_blocked_hosts =
std::make_shared<ALLOWED_AND_BLOCKED_HOSTS>(endpoint_info->get_static_members(), std::set<std::string>());
} else {
allowed_and_blocked_hosts =
} else {
allowed_and_blocked_hosts =
std::make_shared<ALLOWED_AND_BLOCKED_HOSTS>(std::set<std::string>(), endpoint_info->get_excluded_members());
}
}

custom_endpoint_cache.put(this->custom_endpoint_host_info->get_host(), endpoint_info,
CUSTOM_ENDPOINT_INFO_EXPIRATION_NANOS);
const long long elapsed_time =
this->topology_service->set_allowed_and_blocked_hosts(allowed_and_blocked_hosts);
custom_endpoint_cache.put(this->custom_endpoint_host, endpoint_info, CUSTOM_ENDPOINT_INFO_EXPIRATION_NANOS);
const long long elapsed_time =
std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::steady_clock::now() - start).count();
std::this_thread::sleep_for(
std::this_thread::sleep_for(
std::chrono::nanoseconds(std::max(static_cast<long long>(0), this->refresh_rate_nanos - elapsed_time)));
}

} catch (const std::exception& e) {
// Log and continue monitoring.
MYLOG_TRACE(this->logger, 0, "Error while monitoring custom endpoint: %s", e.what());
}
});

--SDK_HELPER;
} catch (const std::exception& e) {
// Log the error. The exception has already left the while loop, so monitoring stops here.
--SDK_HELPER;
MYLOG_TRACE(this->logger, 0, "Error while monitoring custom endpoint: %s", e.what());
}

should_stop = true;
}

std::string CUSTOM_ENDPOINT_MONITOR::get_endpoints_as_string(
const std::vector<Aws::RDS::Model::DBClusterEndpoint>& custom_endpoints) {
const std::vector<Aws::RDS::Model::DBClusterEndpoint>& custom_endpoints) {
if (custom_endpoints.empty()) {
return "<no endpoints>";
}
@@ -168,12 +190,12 @@ std::string CUSTOM_ENDPOINT_MONITOR::get_endpoints_as_string(
}

void CUSTOM_ENDPOINT_MONITOR::stop() {
this->should_stop.store(true);
this->thread_pool.stop(true);
this->thread_pool.resize(0);
custom_endpoint_cache.remove(this->custom_endpoint_host_info->get_host());
--SDK_HELPER;
MYLOG_TRACE(this->logger, 0, "Stopped custom endpoint monitor for '%s'", this->custom_endpoint_host_info->get_host().c_str());
should_stop = true;
//thread_pool.stop(false);
//thread_pool.resize(0);
thread_pool.join();
custom_endpoint_cache.remove(this->custom_endpoint_host);
MYLOG_TRACE(this->logger, 0, "Stopped custom endpoint monitor for '%s'", this->custom_endpoint_host.c_str());
}

void CUSTOM_ENDPOINT_MONITOR::clear_cache() { custom_endpoint_cache.clear(); }
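
The threading pattern this change moves to (a single `std::thread` running a refresh loop, a stop flag, and a `join()` in `stop()`) can be sketched on its own. This is a minimal illustration under stated assumptions, not the driver's implementation: `PeriodicMonitorSketch` and its `refresh` callback are hypothetical, the callback stands in for the `DescribeDBClusterEndpoints` call and cache update, and the stop flag is assumed to be a `std::atomic<bool>`.

```cpp
#include <algorithm>
#include <atomic>
#include <cassert>
#include <chrono>
#include <functional>
#include <thread>
#include <utility>

// Hypothetical monitor: calls `refresh` roughly once per refresh period until stopped.
class PeriodicMonitorSketch {
 public:
  PeriodicMonitorSketch(std::function<void()> refresh, std::chrono::nanoseconds refresh_rate)
      : refresh(std::move(refresh)), refresh_rate(refresh_rate) {
    assert(refresh_rate.count() >= 0);
    // All members are initialized at this point, so the worker can start safely.
    worker = std::thread(&PeriodicMonitorSketch::run, this);
  }

  ~PeriodicMonitorSketch() { stop(); }

  // Signals the loop to exit and waits for the worker thread to finish.
  void stop() {
    should_stop = true;
    if (worker.joinable()) {
      worker.join();
    }
  }

 private:
  void run() {
    while (!should_stop) {
      const auto start = std::chrono::steady_clock::now();
      refresh();
      const auto elapsed =
          std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::steady_clock::now() - start);
      // Sleep only for the part of the period that the refresh did not use.
      std::this_thread::sleep_for(std::max(std::chrono::nanoseconds(0), refresh_rate - elapsed));
    }
  }

  std::function<void()> refresh;
  std::chrono::nanoseconds refresh_rate;
  std::atomic<bool> should_stop{false};
  std::thread worker;
};
```

In this sketch the sleep cannot be interrupted, so `stop()` may block for up to one refresh period while it waits in `join()`. Waiting on a `std::condition_variable` with a timeout would be one way to shorten that delay.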