Skip to content

Commit

Permalink
feat: custom endpoint support (#218)
Browse files Browse the repository at this point in the history
  • Loading branch information
karenc-bq authored Jan 28, 2025
1 parent 45f0904 commit f65ffec
Show file tree
Hide file tree
Showing 52 changed files with 2,056 additions and 270 deletions.
1 change: 1 addition & 0 deletions .github/workflows/failover.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
name: Failover Unit Tests

on:
workflow_dispatch:
push:
branches:
- main
Expand Down
Binary file added docs/images/sample_custom_endpoints_dsn.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
26 changes: 26 additions & 0 deletions docs/using-the-aws-driver/CustomEndpoint.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Custom Endpoint Support

The Custom Endpoint support allows client application to use the driver with [RDS custom endpoints](https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/Aurora.Endpoints.Custom.html). When the Custom Endpoint feature is enabled, the driver will analyse custom endpoint information to ensure instances used in connections are part of the custom endpoint being used. This includes connections used in failover.

## How to use the Driver with Custom Endpoint

### Enabling the Custom Endpoint Feature

1. If needed, create a custom endpoint using the AWS RDS Console:
- If needed, review the documentation about [creating a custom endpoint](https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/aurora-custom-endpoint-creating.html).
2. Set `ENABLE_CUSTOM_ENDPOINT_MONITORING` to `TRUE` to enable custom endpoint support.
3. If you are using the failover plugin, set the failover parameter `FAILOVER_MODE` according to the custom endpoint type. For example, if the custom endpoint you are using is of type `READER`, you can set `FAILOVER_MODE` to `strict reader`, or if it is of type `ANY`, you can set `FAILOVER_MODE` to `reader or writer`.
4. Specify parameters that are required or specific to your case.

### Custom Endpoint Plugin Parameters

| Parameter | Value | Required | Description | Default Value | Example Value |
| ------------------------------------------ | :----: | :------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------c--------------------------------------------------------------------------------------------- | --------------------- | ------------- |
| `ENABLE_CUSTOM_ENDPOINT_MONITORING` | bool | No | Set to TRUE to enable custom endpoint support. | `FALSE` | `TRUE` |
| `CUSTOM_ENDPOINT_REGION` | string | No | The region of the cluster's custom endpoints. If not specified, the region will be parsed from the URL. | `N/A` | `us-west-1` |
| `CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS` | long | No | Controls how frequently custom endpoint monitors fetch custom endpoint info, in milliseconds. | `30000` | `20000` |
| `CUSTOM_ENDPOINT_MONITOR_EXPIRATION_MS` | long | No | Controls how long a monitor should run without use before expiring and being removed, in milliseconds. | `900000` (15 minutes) | `600000` |
| `WAIT_FOR_CUSTOM_ENDPOINT_INFO` | bool | No | Controls whether to wait for custom endpoint info to become available before connecting or executing a method. Waiting is only necessary if a connection to a given custom endpoint has not been opened or used recently. Note that disabling this may result in occasional connections to instances outside of the custom endpoint. | `TRUE` | `TRUE` |
| `WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS` | long | No | Controls the maximum amount of time that the plugin will wait for custom endpoint info to be made available by the custom endpoint monitor, in milliseconds. | `5000` | `7000` |

![sample_custom_endpoints_dsn](../images/sample_custom_endpoints_dsn.png)
9 changes: 9 additions & 0 deletions driver/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT})
auth_util.cc
aws_sdk_helper.cc
base_metrics_holder.cc
cache_map.cc
catalog.cc
catalog_no_i_s.cc
cluster_topology_info.cc
Expand All @@ -72,6 +73,9 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT})
connect.cc
connection_handler.cc
connection_proxy.cc
custom_endpoint_info.cc
custom_endpoint_monitor.cc
custom_endpoint_proxy.cc
cursor.cc
desc.cc
dll.cc
Expand Down Expand Up @@ -131,9 +135,11 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT})
CONFIGURE_FILE(${CMAKE_SOURCE_DIR}/driver/driver.rc.cmake ${CMAKE_SOURCE_DIR}/driver/driver${CONNECTOR_DRIVER_TYPE_SHORT}.rc @ONLY)
SET(DRIVER_SRCS ${DRIVER_SRCS} driver${CONNECTOR_DRIVER_TYPE_SHORT}.def driver${CONNECTOR_DRIVER_TYPE_SHORT}.rc
adfs_proxy.h
allowed_and_blocked_hosts.h
auth_util.h
aws_sdk_helper.h
base_metrics_holder.h
cache_map.h
catalog.h
cluster_aware_hit_metrics_holder.h
cluster_aware_metrics_container.h
Expand All @@ -142,6 +148,9 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT})
cluster_topology_info.h
connection_handler.h
connection_proxy.h
custom_endpoint_info.h
custom_endpoint_monitor.h
custom_endpoint_proxy.h
driver.h
efm_proxy.h
error.h
Expand Down
74 changes: 74 additions & 0 deletions driver/allowed_and_blocked_hosts.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License, version 2.0
// (GPLv2), as published by the Free Software Foundation, with the
// following additional permissions:
//
// This program is distributed with certain software that is licensed
// under separate terms, as designated in a particular file or component
// or in the license documentation. Without limiting your rights under
// the GPLv2, the authors of this program hereby grant you an additional
// permission to link the program and your derivative works with the
// separately licensed software that they have included with the program.
//
// Without limiting the foregoing grant of rights under the GPLv2 and
// additional permission as to separately licensed software, this
// program is also subject to the Universal FOSS Exception, version 1.0,
// a copy of which can be found along with its FAQ at
// http://oss.oracle.com/licenses/universal-foss-exception.
//
// This program is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
// See the GNU General Public License, version 2.0, for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see
// http://www.gnu.org/licenses/gpl-2.0.html.

#ifndef __ALLOWED_AND_BLOCKED_HOSTS__
#define __ALLOWED_AND_BLOCKED_HOSTS__

#include <set>
#include <string>

/**
* Represents the allowed and blocked hosts for connections.
*/
class ALLOWED_AND_BLOCKED_HOSTS {
public:
/**
* Constructs an AllowedAndBlockedHosts instance with the specified allowed and blocked host IDs.
* @param allowed_host_ids The set of allowed host IDs for connections. If null or empty, all host IDs that are not in
* `blocked_host_ids` are allowed.
* @param blocked_host_ids The set of blocked host IDs for connections. If null or empty, all host IDs in
* `allowed_host_ids` are allowed. If `allowed_host_ids` is also null or empty, there
* are no restrictions on which hosts are allowed.
*/
ALLOWED_AND_BLOCKED_HOSTS(const std::set<std::string>& allowed_host_ids,
const std::set<std::string>& blocked_host_ids)
: allowed_host_ids(allowed_host_ids), blocked_host_ids(blocked_host_ids){};

/**
* Returns the set of allowed host IDs for connections. If null or empty, all host IDs that are not in
* `blocked_host_ids` are allowed.
*
* @return the set of allowed host IDs for connections.
*/
std::set<std::string> get_allowed_host_ids() { return this->allowed_host_ids; };

/**
* Returns the set of blocked host IDs for connections. If null or empty, all host IDs in `allowed_host_ids`
* are allowed. If `allowed_host_ids` is also null or empty, there are no restrictions on which hosts are allowed.
*
* @return the set of blocked host IDs for connections.
*/
std::set<std::string> get_blocked_host_ids() { return this->blocked_host_ids; };

private:
std::set<std::string> allowed_host_ids;
std::set<std::string> blocked_host_ids;
};

#endif
2 changes: 1 addition & 1 deletion driver/auth_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ std::pair<std::string, bool> AUTH_UTIL::get_auth_token(std::unordered_map<std::s
}

std::string auth_token;
const std::string cache_key = this->build_cache_key(host, region, port, user);
const std::string cache_key = build_cache_key(host, region, port, user);
bool using_cached_token = false;

{
Expand Down
95 changes: 95 additions & 0 deletions driver/cache_map.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License, version 2.0
// (GPLv2), as published by the Free Software Foundation, with the
// following additional permissions:
//
// This program is distributed with certain software that is licensed
// under separate terms, as designated in a particular file or component
// or in the license documentation. Without limiting your rights under
// the GPLv2, the authors of this program hereby grant you an additional
// permission to link the program and your derivative works with the
// separately licensed software that they have included with the program.
//
// Without limiting the foregoing grant of rights under the GPLv2 and
// additional permission as to separately licensed software, this
// program is also subject to the Universal FOSS Exception, version 1.0,
// a copy of which can be found along with its FAQ at
// http://oss.oracle.com/licenses/universal-foss-exception.
//
// This program is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
// See the GNU General Public License, version 2.0, for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see
// http://www.gnu.org/licenses/gpl-2.0.html.

#include "cache_map.h"

#include <utility>

#include "custom_endpoint_info.h"

template <class K, class V>
void CACHE_MAP<K, V>::put(K key, V value, long long item_expiration_nanos) {
this->cache[key] = std::make_shared<CACHE_ITEM>(
value, std::chrono::steady_clock::now() + std::chrono::nanoseconds(item_expiration_nanos));
this->clean_up();
}

template <class K, class V>
V CACHE_MAP<K, V>::get(K key, V default_value) {
if (cache.count(key) > 0 && !cache[key]->is_expired()) {
return this->cache[key]->item;
}
return default_value;
}

template <class K, class V>
V CACHE_MAP<K, V>::get(K key, V default_value, long long item_expiration_nanos) {
if (cache.count(key) == 0 || this->cache[key]->is_expired()) {
this->put(key, std::move(default_value), item_expiration_nanos);
}
return this->cache[key]->item;
}

template <class K, class V>
void CACHE_MAP<K, V>::remove(K key) {
if (this->cache.count(key)) {
this->cache.erase(key);
}
this->clean_up();
}

template <class K, class V>
int CACHE_MAP<K, V>::size() {
return this->cache.size();
}

template <class K, class V>
void CACHE_MAP<K, V>::clear() {
this->cache.clear();
}

template <class K, class V>
void CACHE_MAP<K, V>::clean_up() {
if (std::chrono::steady_clock::now() > this->clean_up_time_nanos.load()) {
this->clean_up_time_nanos =
std::chrono::steady_clock::now() + std::chrono::nanoseconds(this->clean_up_time_interval_nanos);
std::vector<K> keys;
keys.reserve(this->cache.size());
for (auto& [key, cache_item] : this->cache) {
keys.push_back(key);
}
for (const auto& key : keys) {
if (this->cache[key]->is_expired()) {
this->cache.erase(key);
}
}
}
}

template class CACHE_MAP<std::string, std::shared_ptr<CUSTOM_ENDPOINT_INFO>>;
74 changes: 74 additions & 0 deletions driver/cache_map.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License, version 2.0
// (GPLv2), as published by the Free Software Foundation, with the
// following additional permissions:
//
// This program is distributed with certain software that is licensed
// under separate terms, as designated in a particular file or component
// or in the license documentation. Without limiting your rights under
// the GPLv2, the authors of this program hereby grant you an additional
// permission to link the program and your derivative works with the
// separately licensed software that they have included with the program.
//
// Without limiting the foregoing grant of rights under the GPLv2 and
// additional permission as to separately licensed software, this
// program is also subject to the Universal FOSS Exception, version 1.0,
// a copy of which can be found along with its FAQ at
// http://oss.oracle.com/licenses/universal-foss-exception.
//
// This program is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
// See the GNU General Public License, version 2.0, for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see
// http://www.gnu.org/licenses/gpl-2.0.html.

#ifndef __CACHE_MAP_H__
#define __CACHE_MAP_H__

#include <atomic>
#include <chrono>
#include <memory>
#include <unordered_map>

template <class K, class V>
class CACHE_MAP {
public:
class CACHE_ITEM {
public:
CACHE_ITEM() = default;
CACHE_ITEM(V item, std::chrono::steady_clock::time_point expiration_time)
: item(item), expiration_time(expiration_time){};
~CACHE_ITEM() = default;
V item;

bool is_expired() { return std::chrono::steady_clock::now() > this->expiration_time; }

private:
std::chrono::steady_clock::time_point expiration_time;
};

CACHE_MAP() = default;
~CACHE_MAP() = default;

void put(K key, V value, long long item_expiration_nanos);
V get(K key, V default_value);
V get(K key, V default_value, long long item_expiration_nanos);
void remove(K key);
int size();
void clear();

protected:
void clean_up();
const long long clean_up_time_interval_nanos = 60000000000; // 10 minute
std::atomic<std::chrono::steady_clock::time_point> clean_up_time_nanos;

private:
std::unordered_map<K, std::shared_ptr<CACHE_ITEM>> cache;
};

#endif
21 changes: 21 additions & 0 deletions driver/cluster_topology_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "cluster_topology_info.h"

#include <stdexcept>
#include <algorithm>

/**
Initialize and return random number.
Expand Down Expand Up @@ -75,6 +76,19 @@ void CLUSTER_TOPOLOGY_INFO::add_host(std::shared_ptr<HOST_INFO> host_info) {
update_time();
}

void CLUSTER_TOPOLOGY_INFO::remove_host(std::shared_ptr<HOST_INFO> host_info) {
auto position = std::find(writers.begin(), writers.end(), host_info);
if (position != writers.end()) {
writers.erase(position);
}

position = std::find(readers.begin(), readers.end(), host_info);
if (position != readers.end()) {
readers.erase(position);
}
update_time();
}

size_t CLUSTER_TOPOLOGY_INFO::total_hosts() {
return writers.size() + readers.size();
}
Expand Down Expand Up @@ -136,6 +150,13 @@ std::vector<std::shared_ptr<HOST_INFO>> CLUSTER_TOPOLOGY_INFO::get_writers() {
return writers;
}

std::vector<std::shared_ptr<HOST_INFO>> CLUSTER_TOPOLOGY_INFO::get_instances() {
std::vector instances(writers);
instances.insert(instances.end(), readers.begin(), readers.end());

return instances;
}

std::shared_ptr<HOST_INFO> CLUSTER_TOPOLOGY_INFO::get_last_used_reader() {
return last_used_reader;
}
Expand Down
2 changes: 2 additions & 0 deletions driver/cluster_topology_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class CLUSTER_TOPOLOGY_INFO {
virtual ~CLUSTER_TOPOLOGY_INFO();

void add_host(std::shared_ptr<HOST_INFO> host_info);
void remove_host(std::shared_ptr<HOST_INFO> host_info);
size_t total_hosts();
size_t num_readers(); // return number of readers in the cluster
std::time_t time_last_updated();
Expand All @@ -58,6 +59,7 @@ class CLUSTER_TOPOLOGY_INFO {
std::shared_ptr<HOST_INFO> get_reader(int i);
std::vector<std::shared_ptr<HOST_INFO>> get_writers();
std::vector<std::shared_ptr<HOST_INFO>> get_readers();
std::vector<std::shared_ptr<HOST_INFO>> get_instances();

private:
int current_reader = -1;
Expand Down
3 changes: 3 additions & 0 deletions driver/connection_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ CONNECTION_PROXY* CONNECTION_HANDLER::connect(std::shared_ptr<HOST_INFO> host_in
}

my_SQLFreeConnect(dbc_clone);
if (new_connection != nullptr) {
new_connection->set_dbc(dbc);
}

return new_connection;
}
Expand Down
Loading

0 comments on commit f65ffec

Please sign in to comment.