test: custom endpoints unit tests and integration tests
karenc-bq committed Jan 21, 2025
1 parent 07f6b7d commit 4bbc8c9
Showing 35 changed files with 1,336 additions and 636 deletions.
1 change: 1 addition & 0 deletions .github/workflows/failover.yml
@@ -1,6 +1,7 @@
name: Failover Unit Tests

on:
workflow_dispatch:
push:
branches:
- main
24 changes: 24 additions & 0 deletions docs/using-the-aws-driver/CustomEndpoint.md
@@ -0,0 +1,24 @@
# Custom Endpoint Support

Custom Endpoint support allows client applications to use the driver with RDS custom endpoints. When the Custom Endpoint feature is enabled, the driver analyses custom endpoint information to ensure that the instances used in connections are part of the custom endpoint being used. This includes connections opened during failover.

## How to use the Driver with Custom Endpoint

### Enabling the Custom Endpoint Feature

1. If needed, create a custom endpoint using the AWS RDS Console:
   - See the documentation about [creating a custom endpoint](https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/aurora-custom-endpoint-creating.html) for details.
2. Set `ENABLE_CUSTOM_ENDPOINT_MONITORING` to `TRUE` to enable custom endpoint support.
3. If you are using the failover plugin, set the failover parameter `FAILOVER_MODE` according to the custom endpoint type. For example, if the custom endpoint you are using is of type `READER`, you can set `FAILOVER_MODE` to `strict-reader`, or if it is of type `ANY`, you can set `FAILOVER_MODE` to `reader-or-writer`.
4. Specify parameters that are required or specific to your case.
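
The mapping described in step 3 can be sketched as a small lookup. This is only an illustration and not part of the driver: `suggested_failover_mode` is a hypothetical helper, and it covers just the two endpoint types named above. Any other type yields an empty string, meaning "no suggestion from this document".

```cpp
#include <string>

// Hypothetical helper (an assumption for illustration, not a driver API).
// Maps a custom endpoint type to the FAILOVER_MODE value suggested in step 3.
// Only the two cases stated in this document are handled.
std::string suggested_failover_mode(const std::string& endpoint_type) {
  if (endpoint_type == "READER") {
    return "strict-reader";
  }
  if (endpoint_type == "ANY") {
    return "reader-or-writer";
  }
  // Endpoint types not covered above: no suggestion is made here.
  return "";
}
```

An application could call `suggested_failover_mode("READER")` while assembling its connection settings and fall back to its own default when the result is empty.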

### Custom Endpoint Plugin Parameters

| Parameter | Value | Required | Description | Default Value | Example Value |
| ------------------------------------------ | :----: | :------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------- | ------------- |
| `ENABLE_CUSTOM_ENDPOINT_MONITORING` | bool | No | Set to TRUE to enable custom endpoint support. | `FALSE` | `TRUE` |
| `CUSTOM_ENDPOINT_REGION` | string | No | The region of the cluster's custom endpoints. If not specified, the region will be parsed from the URL. | `N/A` | `us-west-1` |
| `CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS` | long | No | Controls how frequently custom endpoint monitors fetch custom endpoint info, in milliseconds. | `30000` | `20000` |
| `CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS` | long | No | Controls how long a monitor should run without use before expiring and being removed, in milliseconds. | `900000` (15 minutes) | `600000` |
| `WAIT_FOR_CUSTOM_ENDPOINT_INFO` | bool | No | Controls whether to wait for custom endpoint info to become available before connecting or executing a method. Waiting is only necessary if a connection to a given custom endpoint has not been opened or used recently. Note that disabling this may result in occasional connections to instances outside of the custom endpoint. | `true` | `true` |
| `WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS` | long | No | Controls the maximum amount of time that the plugin will wait for custom endpoint info to be made available by the custom endpoint monitor, in milliseconds. | `5000` | `7000` |
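
As a rough sketch of how the parameters above might be combined, the snippet below assembles an ODBC-style `KEY=value;` connection string. The helper names, the `DSN` value, and the `SERVER` host are placeholders assumed for illustration; only the custom endpoint parameter names and example values come from the table. How the connection string is passed to the driver is not shown.

```cpp
#include <string>
#include <utility>
#include <vector>

using Option = std::pair<std::string, std::string>;

// Joins key/value pairs into a "KEY=value;" string, in the given order.
std::string build_connection_string(const std::vector<Option>& options) {
  std::string result;
  for (const Option& option : options) {
    result += option.first + "=" + option.second + ";";
  }
  return result;
}

// Hypothetical example: DSN and SERVER are placeholders, not real resources.
std::string example_custom_endpoint_connection_string() {
  return build_connection_string({
      {"DSN", "my-dsn"},
      {"SERVER", "my-endpoint.cluster-custom-XYZ.us-west-1.rds.amazonaws.com"},
      {"ENABLE_CUSTOM_ENDPOINT_MONITORING", "TRUE"},
      {"CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS", "20000"},
      {"WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS", "7000"},
  });
}
```

Parameters left out of the string, such as `CUSTOM_ENDPOINT_REGION`, fall back to the defaults listed in the table.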
4 changes: 2 additions & 2 deletions driver/CMakeLists.txt
@@ -73,9 +73,9 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT})
connect.cc
connection_handler.cc
connection_proxy.cc
custom_endpoint_proxy.cc
custom_endpoint_info.cc
custom_endpoint_monitor.cc
custom_endpoint_proxy.cc
cursor.cc
desc.cc
dll.cc
@@ -148,9 +148,9 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT})
cluster_topology_info.h
connection_handler.h
connection_proxy.h
custom_endpoint_proxy.h
custom_endpoint_info.h
custom_endpoint_monitor.h
custom_endpoint_proxy.h
driver.h
efm_proxy.h
error.h
4 changes: 2 additions & 2 deletions driver/cache_map.h
@@ -27,8 +27,8 @@
// along with this program. If not, see
// http://www.gnu.org/licenses/gpl-2.0.html.

#ifndef __CACHE_MAP__
#define __CACHE_MAP__
#ifndef __CACHE_MAP_H__
#define __CACHE_MAP_H__

#include <atomic>
#include <chrono>
22 changes: 22 additions & 0 deletions driver/cluster_topology_info.cc
@@ -30,6 +30,7 @@
#include "cluster_topology_info.h"

#include <stdexcept>
#include <algorithm>

/**
Initialize and return random number.
@@ -75,6 +76,20 @@ void CLUSTER_TOPOLOGY_INFO::add_host(std::shared_ptr<HOST_INFO> host_info) {
update_time();
}

void CLUSTER_TOPOLOGY_INFO::remove_host(std::shared_ptr<HOST_INFO> host_info) {
auto position = std::find(writers.begin(), writers.end(), host_info);
if (position != writers.end()) {
writers.erase(position);
}

position = std::find(readers.begin(), readers.end(), host_info);
if (position != readers.end()) {
readers.erase(position);
}
update_time();
}


size_t CLUSTER_TOPOLOGY_INFO::total_hosts() {
return writers.size() + readers.size();
}
@@ -136,6 +151,13 @@ std::vector<std::shared_ptr<HOST_INFO>> CLUSTER_TOPOLOGY_INFO::get_writers() {
return writers;
}

std::vector<std::shared_ptr<HOST_INFO>> CLUSTER_TOPOLOGY_INFO::get_instances() {
std::vector instances(writers);
// Append the readers after the copied writers so every instance appears once.
instances.insert(instances.end(), readers.begin(), readers.end());

return instances;
}

std::shared_ptr<HOST_INFO> CLUSTER_TOPOLOGY_INFO::get_last_used_reader() {
return last_used_reader;
}
2 changes: 2 additions & 0 deletions driver/cluster_topology_info.h
@@ -46,6 +46,7 @@ class CLUSTER_TOPOLOGY_INFO {
virtual ~CLUSTER_TOPOLOGY_INFO();

void add_host(std::shared_ptr<HOST_INFO> host_info);
void remove_host(std::shared_ptr<HOST_INFO> host_info);
size_t total_hosts();
size_t num_readers(); // return number of readers in the cluster
std::time_t time_last_updated();
@@ -58,6 +59,7 @@ class CLUSTER_TOPOLOGY_INFO {
std::shared_ptr<HOST_INFO> get_reader(int i);
std::vector<std::shared_ptr<HOST_INFO>> get_writers();
std::vector<std::shared_ptr<HOST_INFO>> get_readers();
std::vector<std::shared_ptr<HOST_INFO>> get_instances();

private:
int current_reader = -1;
5 changes: 1 addition & 4 deletions driver/custom_endpoint_info.h
@@ -33,10 +33,7 @@
#include <aws/rds/model/DBClusterEndpoint.h>

#include <set>
#include <sstream>
#include <utility>

#include "MYODBC_MYSQL.h"
#include "stringutil.h"
#include "mylog.h"

/**
80 changes: 51 additions & 29 deletions driver/custom_endpoint_monitor.cc
@@ -27,18 +27,17 @@
// along with this program. If not, see
// http://www.gnu.org/licenses/gpl-2.0.html.

#include "custom_endpoint_monitor.h"

#include <aws/core/auth/AWSCredentialsProviderChain.h>
#include <aws/rds/model/DBClusterEndpoint.h>
#include <aws/rds/model/DescribeDBClusterEndpointsRequest.h>
#include <aws/rds/model/Filter.h>
#include <utility>
#include <vector>

#include "allowed_and_blocked_hosts.h"
#include "aws_sdk_helper.h"
#include "custom_endpoint_monitor.h"
#include "driver.h"
#include "monitor_service.h"
#include "mylog.h"

namespace {
@@ -47,13 +46,16 @@ AWS_SDK_HELPER SDK_HELPER;

CACHE_MAP<std::string, std::shared_ptr<CUSTOM_ENDPOINT_INFO>> CUSTOM_ENDPOINT_MONITOR::custom_endpoint_cache;

CUSTOM_ENDPOINT_MONITOR::CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr<HOST_INFO>& custom_endpoint_host_info,
CUSTOM_ENDPOINT_MONITOR::CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr<TOPOLOGY_SERVICE> topology_service,
const std::string& custom_endpoint_host,
const std::string& endpoint_identifier, const std::string& region,
DataSource* ds, bool enable_logging)
: custom_endpoint_host_info(custom_endpoint_host_info),
endpoint_identifier(endpoint_identifier),
region(region),
enable_logging(enable_logging) {
long long refresh_rate_nanos, bool enable_logging)
: topology_service(topology_service),
custom_endpoint_host(custom_endpoint_host),
endpoint_identifier(endpoint_identifier),
region(region),
refresh_rate_nanos(refresh_rate_nanos),
enable_logging(enable_logging) {
if (enable_logging) {
this->logger = init_log_file();
}
@@ -66,23 +68,42 @@ CUSTOM_ENDPOINT_MONITOR::CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr<HOST_INFO
}

this->rds_client = std::make_shared<Aws::RDS::RDSClient>(
Aws::Auth::DefaultAWSCredentialsProviderChain().GetAWSCredentials(), client_config);
Aws::Auth::DefaultAWSCredentialsProviderChain().GetAWSCredentials(), client_config);

this->run();
}

#ifdef UNIT_TEST_BUILD
CUSTOM_ENDPOINT_MONITOR::CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr<TOPOLOGY_SERVICE> topology_service,
const std::string& custom_endpoint_host,
const std::string& endpoint_identifier, const std::string& region,
long long refresh_rate_nanos, bool enable_logging,
std::shared_ptr<Aws::RDS::RDSClient> client)
: topology_service(topology_service),
custom_endpoint_host(custom_endpoint_host),
endpoint_identifier(endpoint_identifier),
region(region),
refresh_rate_nanos(refresh_rate_nanos),
enable_logging(enable_logging),
rds_client(std::move(client)) {
if (enable_logging) {
this->logger = init_log_file();
}
this->run();
}
#endif

bool CUSTOM_ENDPOINT_MONITOR::should_dispose() { return true; }

bool CUSTOM_ENDPOINT_MONITOR::has_custom_endpoint_info() const {
auto default_val = std::shared_ptr<CUSTOM_ENDPOINT_INFO>(nullptr);
return custom_endpoint_cache.get(this->custom_endpoint_host_info->get_host(), default_val) != default_val;
return custom_endpoint_cache.get(this->custom_endpoint_host, default_val) != default_val;
}

void CUSTOM_ENDPOINT_MONITOR::run() {
this->thread_pool.resize(1);
this->thread_pool.push([=](int id) {
MYLOG_TRACE(this->logger, 0, "Starting custom endpoint monitor for '%s'",
this->custom_endpoint_host_info->get_host().c_str());
MYLOG_TRACE(this->logger, 0, "Starting custom endpoint monitor for '%s'", this->custom_endpoint_host.c_str());

try {
while (!this->should_stop.load()) {
@@ -92,7 +113,7 @@ void CUSTOM_ENDPOINT_MONITOR::run() {
filter.SetValues({"custom"});

Aws::RDS::Model::DescribeDBClusterEndpointsRequest request;
request.SetDBClusterIdentifier(this->endpoint_identifier);
request.SetDBClusterEndpointIdentifier(this->endpoint_identifier);
request.SetFilters({filter});
const auto response = this->rds_client->DescribeDBClusterEndpoints(request);

@@ -108,37 +129,37 @@ void CUSTOM_ENDPOINT_MONITOR::run() {
continue;
}
const std::shared_ptr<CUSTOM_ENDPOINT_INFO> endpoint_info =
CUSTOM_ENDPOINT_INFO::from_db_cluster_endpoint(custom_endpoints[0]);
CUSTOM_ENDPOINT_INFO::from_db_cluster_endpoint(custom_endpoints[0]);
const std::shared_ptr<CUSTOM_ENDPOINT_INFO> cache_endpoint_info =
custom_endpoint_cache.get(this->custom_endpoint_host_info->get_host(), nullptr);
custom_endpoint_cache.get(this->custom_endpoint_host, nullptr);

if (cache_endpoint_info != nullptr && cache_endpoint_info == endpoint_info) {
const long long elapsed_time =
std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::steady_clock::now() - start).count();
std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::steady_clock::now() - start).count();
std::this_thread::sleep_for(
std::chrono::nanoseconds(std::max(static_cast<long long>(0), this->refresh_rate_nanos - elapsed_time)));
std::chrono::nanoseconds(std::max(static_cast<long long>(0), this->refresh_rate_nanos - elapsed_time)));
continue;
}

MYLOG_TRACE(this->logger, 0, "Detected change in custom endpoint info for '%s':\n{%s}",
custom_endpoint_host_info->get_host().c_str(), endpoint_info->to_string().c_str());
custom_endpoint_host.c_str(), endpoint_info->to_string().c_str());

// The custom endpoint info has changed, so we need to update the set of allowed/blocked hosts.
std::shared_ptr<ALLOWED_AND_BLOCKED_HOSTS> allowed_and_blocked_hosts;
if (endpoint_info->get_member_list_type() == STATIC_LIST) {
allowed_and_blocked_hosts =
std::make_shared<ALLOWED_AND_BLOCKED_HOSTS>(endpoint_info->get_static_members(), std::set<std::string>());
std::make_shared<ALLOWED_AND_BLOCKED_HOSTS>(endpoint_info->get_static_members(), std::set<std::string>());
} else {
allowed_and_blocked_hosts =
std::make_shared<ALLOWED_AND_BLOCKED_HOSTS>(std::set<std::string>(), endpoint_info->get_excluded_members());
allowed_and_blocked_hosts = std::make_shared<ALLOWED_AND_BLOCKED_HOSTS>(
std::set<std::string>(), endpoint_info->get_excluded_members());
}

custom_endpoint_cache.put(this->custom_endpoint_host_info->get_host(), endpoint_info,
CUSTOM_ENDPOINT_INFO_EXPIRATION_NANOS);
this->topology_service->set_allowed_and_blocked_hosts(allowed_and_blocked_hosts);
custom_endpoint_cache.put(this->custom_endpoint_host, endpoint_info, CUSTOM_ENDPOINT_INFO_EXPIRATION_NANOS);
const long long elapsed_time =
std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::steady_clock::now() - start).count();
std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::steady_clock::now() - start).count();
std::this_thread::sleep_for(
std::chrono::nanoseconds(std::max(static_cast<long long>(0), this->refresh_rate_nanos - elapsed_time)));
std::chrono::nanoseconds(std::max(static_cast<long long>(0), this->refresh_rate_nanos - elapsed_time)));
}

} catch (const std::exception& e) {
@@ -149,7 +170,7 @@ void CUSTOM_ENDPOINT_MONITOR::run() {
}

std::string CUSTOM_ENDPOINT_MONITOR::get_endpoints_as_string(
const std::vector<Aws::RDS::Model::DBClusterEndpoint>& custom_endpoints) {
const std::vector<Aws::RDS::Model::DBClusterEndpoint>& custom_endpoints) {
if (custom_endpoints.empty()) {
return "<no endpoints>";
}
@@ -171,9 +192,10 @@ void CUSTOM_ENDPOINT_MONITOR::stop() {
this->should_stop.store(true);
this->thread_pool.stop(true);
this->thread_pool.resize(0);
custom_endpoint_cache.remove(this->custom_endpoint_host_info->get_host());
custom_endpoint_cache.remove(this->custom_endpoint_host);
this->rds_client.reset();
--SDK_HELPER;
MYLOG_TRACE(this->logger, 0, "Stopped custom endpoint monitor for '%s'", this->custom_endpoint_host_info->get_host().c_str());
MYLOG_TRACE(this->logger, 0, "Stopped custom endpoint monitor for '%s'", this->custom_endpoint_host.c_str());
}

void CUSTOM_ENDPOINT_MONITOR::clear_cache() { custom_endpoint_cache.clear(); }
24 changes: 19 additions & 5 deletions driver/custom_endpoint_monitor.h
@@ -34,15 +34,23 @@

#include <ctpl_stl.h>
#include "cache_map.h"
#include "connection_handler.h"
#include "connection_proxy.h"
#include "custom_endpoint_info.h"
#include "host_info.h"
#include "topology_service.h"

class CUSTOM_ENDPOINT_MONITOR : public std::enable_shared_from_this<CUSTOM_ENDPOINT_MONITOR> {
public:
CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr<HOST_INFO>& custom_endpoint_host_info, const std::string& endpoint_identifier,
const std::string& region, DataSource* ds, bool enable_logging = false);
CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr<TOPOLOGY_SERVICE> topology_service,
const std::string& custom_endpoint_host, const std::string& endpoint_identifier,
const std::string& region, long long refresh_rate_nanos, bool enable_logging = false);
#ifdef UNIT_TEST_BUILD
CUSTOM_ENDPOINT_MONITOR() = default;
CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr<TOPOLOGY_SERVICE> topology_service,
const std::string& custom_endpoint_host, const std::string& endpoint_identifier,
const std::string& region, long long refresh_rate_nanos, bool enable_logging,
std::shared_ptr<Aws::RDS::RDSClient> client);
#endif

~CUSTOM_ENDPOINT_MONITOR() = default;

static bool should_dispose();
@@ -54,7 +62,7 @@ class CUSTOM_ENDPOINT_MONITOR : public std::enable_shared_from_this<CUSTOM_ENDPO
protected:
static CACHE_MAP<std::string, std::shared_ptr<CUSTOM_ENDPOINT_INFO>> custom_endpoint_cache;
static constexpr long long CUSTOM_ENDPOINT_INFO_EXPIRATION_NANOS = 300000000000; // 5 minutes
std::shared_ptr<HOST_INFO> custom_endpoint_host_info;
std::string custom_endpoint_host;
std::string endpoint_identifier;
std::string region;
long long refresh_rate_nanos;
@@ -63,9 +71,15 @@ class CUSTOM_ENDPOINT_MONITOR : public std::enable_shared_from_this<CUSTOM_ENDPO
ctpl::thread_pool thread_pool;
std::atomic_bool should_stop{false};
std::shared_ptr<Aws::RDS::RDSClient> rds_client;
std::shared_ptr<TOPOLOGY_SERVICE> topology_service;

private:
static std::string get_endpoints_as_string(const std::vector<Aws::RDS::Model::DBClusterEndpoint>& custom_endpoints);

#ifdef UNIT_TEST_BUILD
// Allows for testing private/protected methods
friend class TEST_UTILS;
#endif
};

#endif