Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: custom endpoint support #218

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions driver/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT})
auth_util.cc
aws_sdk_helper.cc
base_metrics_holder.cc
cache_map.cc
catalog.cc
catalog_no_i_s.cc
cluster_topology_info.cc
Expand All @@ -72,6 +73,9 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT})
connect.cc
connection_handler.cc
connection_proxy.cc
custom_endpoint_proxy.cc
custom_endpoint_info.cc
custom_endpoint_monitor.cc
cursor.cc
desc.cc
dll.cc
Expand Down Expand Up @@ -131,9 +135,11 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT})
CONFIGURE_FILE(${CMAKE_SOURCE_DIR}/driver/driver.rc.cmake ${CMAKE_SOURCE_DIR}/driver/driver${CONNECTOR_DRIVER_TYPE_SHORT}.rc @ONLY)
SET(DRIVER_SRCS ${DRIVER_SRCS} driver${CONNECTOR_DRIVER_TYPE_SHORT}.def driver${CONNECTOR_DRIVER_TYPE_SHORT}.rc
adfs_proxy.h
allowed_and_blocked_hosts.h
auth_util.h
aws_sdk_helper.h
base_metrics_holder.h
cache_map.h
catalog.h
cluster_aware_hit_metrics_holder.h
cluster_aware_metrics_container.h
Expand All @@ -142,6 +148,9 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT})
cluster_topology_info.h
connection_handler.h
connection_proxy.h
custom_endpoint_proxy.h
custom_endpoint_info.h
custom_endpoint_monitor.h
driver.h
efm_proxy.h
error.h
Expand Down
74 changes: 74 additions & 0 deletions driver/allowed_and_blocked_hosts.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License, version 2.0
// (GPLv2), as published by the Free Software Foundation, with the
// following additional permissions:
//
// This program is distributed with certain software that is licensed
// under separate terms, as designated in a particular file or component
// or in the license documentation. Without limiting your rights under
// the GPLv2, the authors of this program hereby grant you an additional
// permission to link the program and your derivative works with the
// separately licensed software that they have included with the program.
//
// Without limiting the foregoing grant of rights under the GPLv2 and
// additional permission as to separately licensed software, this
// program is also subject to the Universal FOSS Exception, version 1.0,
// a copy of which can be found along with its FAQ at
// http://oss.oracle.com/licenses/universal-foss-exception.
//
// This program is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
// See the GNU General Public License, version 2.0, for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see
// http://www.gnu.org/licenses/gpl-2.0.html.

#ifndef __ALLOWED_AND_BLOCKED_HOSTS__
#define __ALLOWED_AND_BLOCKED_HOSTS__

#include <set>
#include <string>

/**
* Represents the allowed and blocked hosts for connections.
*/
class ALLOWED_AND_BLOCKED_HOSTS {
public:
/**
* Constructs an AllowedAndBlockedHosts instance with the specified allowed and blocked host IDs.
* @param allowed_host_ids The set of allowed host IDs for connections. If null or empty, all host IDs that are not in
* `blocked_host_ids` are allowed.
* @param blocked_host_ids The set of blocked host IDs for connections. If null or empty, all host IDs in
* `allowed_host_ids` are allowed. If `allowed_host_ids` is also null or empty, there
* are no restrictions on which hosts are allowed.
*/
ALLOWED_AND_BLOCKED_HOSTS(const std::set<std::string>& allowed_host_ids,
const std::set<std::string>& blocked_host_ids)
: allowed_host_ids(allowed_host_ids), blocked_host_ids(blocked_host_ids){};

/**
* Returns the set of allowed host IDs for connections. If null or empty, all host IDs that are not in
* `blocked_host_ids` are allowed.
*
* @return the set of allowed host IDs for connections.
*/
std::set<std::string> get_allowed_host_ids() { return this->allowed_host_ids; };

/**
* Returns the set of blocked host IDs for connections. If null or empty, all host IDs in `allowed_host_ids`
* are allowed. If `allowed_host_ids` is also null or empty, there are no restrictions on which hosts are allowed.
*
* @return the set of blocked host IDs for connections.
*/
std::set<std::string> get_blocked_host_ids() { return this->blocked_host_ids; };

private:
std::set<std::string> allowed_host_ids;
std::set<std::string> blocked_host_ids;
};

#endif
2 changes: 1 addition & 1 deletion driver/auth_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ std::pair<std::string, bool> AUTH_UTIL::get_auth_token(std::unordered_map<std::s
}

std::string auth_token;
const std::string cache_key = this->build_cache_key(host, region, port, user);
const std::string cache_key = build_cache_key(host, region, port, user);
bool using_cached_token = false;

{
Expand Down
95 changes: 95 additions & 0 deletions driver/cache_map.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License, version 2.0
// (GPLv2), as published by the Free Software Foundation, with the
// following additional permissions:
//
// This program is distributed with certain software that is licensed
// under separate terms, as designated in a particular file or component
// or in the license documentation. Without limiting your rights under
// the GPLv2, the authors of this program hereby grant you an additional
// permission to link the program and your derivative works with the
// separately licensed software that they have included with the program.
//
// Without limiting the foregoing grant of rights under the GPLv2 and
// additional permission as to separately licensed software, this
// program is also subject to the Universal FOSS Exception, version 1.0,
// a copy of which can be found along with its FAQ at
// http://oss.oracle.com/licenses/universal-foss-exception.
//
// This program is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
// See the GNU General Public License, version 2.0, for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see
// http://www.gnu.org/licenses/gpl-2.0.html.

#include "cache_map.h"

#include <utility>

#include "custom_endpoint_info.h"

template <class K, class V>
void CACHE_MAP<K, V>::put(K key, V value, long long item_expiration_nanos) {
this->cache[key] = std::make_shared<CACHE_ITEM>(
value, std::chrono::steady_clock::now() + std::chrono::nanoseconds(item_expiration_nanos));
this->clean_up();
}

template <class K, class V>
V CACHE_MAP<K, V>::get(K key, V default_value) {
if (cache.count(key) > 0 && !cache[key]->is_expired()) {
return this->cache[key]->item;
}
return default_value;
}

template <class K, class V>
V CACHE_MAP<K, V>::get(K key, V default_value, long long item_expiration_nanos) {
if (cache.count(key) == 0 || this->cache[key]->is_expired()) {
this->put(key, std::move(default_value), item_expiration_nanos);
}
return this->cache[key]->item;
}

template <class K, class V>
void CACHE_MAP<K, V>::remove(K key) {
if (this->cache.count(key)) {
this->cache.erase(key);
}
this->clean_up();
}

template <class K, class V>
int CACHE_MAP<K, V>::size() {
return this->cache.size();
}

template <class K, class V>
void CACHE_MAP<K, V>::clear() {
this->cache.clear();
}

template <class K, class V>
void CACHE_MAP<K, V>::clean_up() {
if (std::chrono::steady_clock::now() > this->clean_up_time_nanos.load()) {
this->clean_up_time_nanos =
std::chrono::steady_clock::now() + std::chrono::nanoseconds(this->clean_up_time_interval_nanos);
std::vector<K> keys;
keys.reserve(this->cache.size());
for (auto& [key, cache_item] : this->cache) {
keys.push_back(key);
}
for (const auto& key : keys) {
if (this->cache[key]->is_expired()) {
this->cache.erase(key);
}
}
}
}

template class CACHE_MAP<std::string, std::shared_ptr<CUSTOM_ENDPOINT_INFO>>;
73 changes: 73 additions & 0 deletions driver/cache_map.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License, version 2.0
// (GPLv2), as published by the Free Software Foundation, with the
// following additional permissions:
//
// This program is distributed with certain software that is licensed
// under separate terms, as designated in a particular file or component
// or in the license documentation. Without limiting your rights under
// the GPLv2, the authors of this program hereby grant you an additional
// permission to link the program and your derivative works with the
// separately licensed software that they have included with the program.
//
// Without limiting the foregoing grant of rights under the GPLv2 and
// additional permission as to separately licensed software, this
// program is also subject to the Universal FOSS Exception, version 1.0,
// a copy of which can be found along with its FAQ at
// http://oss.oracle.com/licenses/universal-foss-exception.
//
// This program is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
// See the GNU General Public License, version 2.0, for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see
// http://www.gnu.org/licenses/gpl-2.0.html.

#ifndef __CACHE_MAP__
#define __CACHE_MAP__

#include <chrono>
#include <memory>
#include <unordered_map>

template <class K, class V>
class CACHE_MAP {
public:
class CACHE_ITEM {
public:
CACHE_ITEM() = default;
CACHE_ITEM(V item, std::chrono::steady_clock::time_point expiration_time)
: item(item), expiration_time(expiration_time){};
~CACHE_ITEM() = default;
V item;

bool is_expired() { return std::chrono::steady_clock::now() > this->expiration_time; }

private:
std::chrono::steady_clock::time_point expiration_time;
};

CACHE_MAP() = default;
~CACHE_MAP() = default;

void put(K key, V value, long long item_expiration_nanos);
V get(K key, V default_value);
V get(K key, V default_value, long long item_expiration_nanos);
void remove(K key);
int size();
void clear();

protected:
void clean_up();
const long long clean_up_time_interval_nanos = 60000000000; // 10 minute
std::atomic<std::chrono::steady_clock::time_point> clean_up_time_nanos;

private:
std::unordered_map<K, std::shared_ptr<CACHE_ITEM>> cache;
};

#endif
84 changes: 84 additions & 0 deletions driver/custom_endpoint_info.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// This program is free software {} you can redistribute it and/or modify
// it under the terms of the GNU General Public License, version 2.0
// (GPLv2), as published by the Free Software Foundation, with the
// following additional permissions:
//
// This program is distributed with certain software that is licensed
// under separate terms, as designated in a particular file or component
// or in the license documentation. Without limiting your rights under
// the GPLv2, the authors of this program hereby grant you an additional
// permission to link the program and your derivative works with the
// separately licensed software that they have included with the program.
//
// Without limiting the foregoing grant of rights under the GPLv2 and
// additional permission as to separately licensed software, this
// program is also subject to the Universal FOSS Exception, version 1.0,
// a copy of which can be found along with its FAQ at
// http://oss.oracle.com/licenses/universal-foss-exception.
//
// This program is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY {} without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
// See the GNU General Public License, version 2.0, for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see
// http://www.gnu.org/licenses/gpl-2.0.html.

#include "custom_endpoint_info.h"

std::shared_ptr<CUSTOM_ENDPOINT_INFO> CUSTOM_ENDPOINT_INFO::from_db_cluster_endpoint(
const Aws::RDS::Model::DBClusterEndpoint& response_endpoint_info) {
std::vector<std::string> members;
MEMBERS_LIST_TYPE members_list_type;

if (response_endpoint_info.StaticMembersHasBeenSet()) {
members = response_endpoint_info.GetStaticMembers();
members_list_type = STATIC_LIST;
} else {
members = response_endpoint_info.GetExcludedMembers();
members_list_type = EXCLUSION_LIST;
}

std::set members_set(members.begin(), members.end());

return std::make_shared<CUSTOM_ENDPOINT_INFO>(
response_endpoint_info.GetDBClusterEndpointIdentifier(), response_endpoint_info.GetDBClusterIdentifier(),
response_endpoint_info.GetEndpoint(),
CUSTOM_ENDPOINT_INFO::get_role_type(response_endpoint_info.GetCustomEndpointType()), members_set,
members_list_type);
}

std::set<std::string> CUSTOM_ENDPOINT_INFO::get_excluded_members() const {
if (this->member_list_type == EXCLUSION_LIST) {
return members;
}

return std::set<std::string>();
}

std::set<std::string> CUSTOM_ENDPOINT_INFO::get_static_members() const {
if (this->member_list_type == STATIC_LIST) {
return members;
}

return std::set<std::string>();
}

bool operator==(const CUSTOM_ENDPOINT_INFO& current, const CUSTOM_ENDPOINT_INFO& other) {
return current.endpoint_identifier == other.endpoint_identifier &&
current.cluster_identifier == other.cluster_identifier && current.url == other.url &&
current.role_type == other.role_type &&
current.member_list_type == other.member_list_type;
}

CUSTOM_ENDPOINT_ROLE_TYPE CUSTOM_ENDPOINT_INFO::get_role_type(const Aws::String& role_type) {
auto it = CUSTOM_ENDPOINT_ROLE_TYPE_MAP.find(role_type);
if (it != CUSTOM_ENDPOINT_ROLE_TYPE_MAP.end()) {
return it->second;
}

throw std::invalid_argument("Invalid role type for custom endpoint, this should not have happened.");
}
Loading
Loading