Skip to content

Commit

Permalink
test: run custom endpoints integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
karenc-bq committed Jan 24, 2025
1 parent cf1f86f commit 7642d33
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 88 deletions.
1 change: 1 addition & 0 deletions .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ on:
push:
branches:
- main
- test/custom-endpoints-integration-tests
pull_request:
branches:
- '*'
Expand Down
150 changes: 76 additions & 74 deletions driver/custom_endpoint_monitor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,14 @@ CACHE_MAP<std::string, std::shared_ptr<CUSTOM_ENDPOINT_INFO>> CUSTOM_ENDPOINT_MO
CUSTOM_ENDPOINT_MONITOR::CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr<TOPOLOGY_SERVICE> topology_service,
const std::string& custom_endpoint_host,
const std::string& endpoint_identifier, const std::string& region,
long long refresh_rate_nanos, bool enable_logging)
long long refresh_rate_nanos, ctpl::thread_pool& thread_pool,
bool enable_logging)
: topology_service(topology_service),
custom_endpoint_host(custom_endpoint_host),
endpoint_identifier(endpoint_identifier),
region(region),
refresh_rate_nanos(refresh_rate_nanos),
thread_pool(thread_pool),
enable_logging(enable_logging) {
if (enable_logging) {
this->logger = init_log_file();
Expand All @@ -67,19 +69,19 @@ CUSTOM_ENDPOINT_MONITOR::CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr<TOPOLOGY_
CUSTOM_ENDPOINT_MONITOR::CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr<TOPOLOGY_SERVICE> topology_service,
const std::string& custom_endpoint_host,
const std::string& endpoint_identifier, const std::string& region,
long long refresh_rate_nanos, bool enable_logging,
std::shared_ptr<Aws::RDS::RDSClient> client)
long long refresh_rate_nanos, ctpl::thread_pool& thread_pool,
bool enable_logging, std::shared_ptr<Aws::RDS::RDSClient> client)
: topology_service(topology_service),
custom_endpoint_host(custom_endpoint_host),
endpoint_identifier(endpoint_identifier),
region(region),
refresh_rate_nanos(refresh_rate_nanos),
thread_pool(thread_pool),
enable_logging(enable_logging) {
if (enable_logging) {
this->logger = init_log_file();
}

thread_pool = std::thread(&CUSTOM_ENDPOINT_MONITOR::run, this);
this->run();
}
#endif

Expand All @@ -91,83 +93,84 @@ bool CUSTOM_ENDPOINT_MONITOR::has_custom_endpoint_info() const {
}

void CUSTOM_ENDPOINT_MONITOR::run() {
++SDK_HELPER;

Aws::RDS::RDSClientConfiguration client_config;
if (!region.empty()) {
client_config.region = region;
}

const Aws::RDS::RDSClient rds_client(Aws::Auth::DefaultAWSCredentialsProviderChain().GetAWSCredentials(),
client_config);
Aws::RDS::Model::Filter filter;
filter.SetName("db-cluster-endpoint-type");
filter.AddValues("custom");

Aws::RDS::Model::DescribeDBClusterEndpointsRequest request;
request.SetDBClusterEndpointIdentifier(this->endpoint_identifier);
// TODO: Investigate why filters returns `InvalidParameterCombination` error saying filter values are null.
// request.AddFilters(filter);

MYLOG_TRACE(this->logger, 0, "Starting custom endpoint monitor for '%s'", this->custom_endpoint_host.c_str());
thread_pool.resize(1);
thread_pool.push([=](int id) {
++SDK_HELPER;

try {
while (!should_stop) {
const std::chrono::time_point start = std::chrono::steady_clock::now();
const auto response = rds_client.DescribeDBClusterEndpoints(request);
const auto custom_endpoints = response.GetResult().GetDBClusterEndpoints();
if (custom_endpoints.size() != 1) {
MYLOG_TRACE(this->logger, 0,
"Unexpected number of custom endpoints with endpoint identifier '%s' in region '%s'. Expected 1 "
"custom endpoint, but found %d. Endpoints: %s",
endpoint_identifier.c_str(), region.c_str(), custom_endpoints.size(),
this->get_endpoints_as_string(custom_endpoints).c_str());

std::this_thread::sleep_for(std::chrono::nanoseconds(this->refresh_rate_nanos));
continue;
}
const std::shared_ptr<CUSTOM_ENDPOINT_INFO> endpoint_info =
CUSTOM_ENDPOINT_INFO::from_db_cluster_endpoint(custom_endpoints[0]);
const std::shared_ptr<CUSTOM_ENDPOINT_INFO> cache_endpoint_info =
custom_endpoint_cache.get(this->custom_endpoint_host, nullptr);
Aws::RDS::RDSClientConfiguration client_config;
if (!region.empty()) {
client_config.region = region;
}

if (cache_endpoint_info != nullptr && cache_endpoint_info == endpoint_info) {
const Aws::RDS::RDSClient rds_client(Aws::Auth::DefaultAWSCredentialsProviderChain().GetAWSCredentials(),
client_config);
Aws::RDS::Model::Filter filter;
filter.SetName("db-cluster-endpoint-type");
filter.AddValues("custom");

Aws::RDS::Model::DescribeDBClusterEndpointsRequest request;
request.SetDBClusterEndpointIdentifier(this->endpoint_identifier);
// TODO: Investigate why filters returns `InvalidParameterCombination` error saying filter values are null.
// request.AddFilters(filter);
try {
while (!should_stop) {
const std::chrono::time_point start = std::chrono::steady_clock::now();
const auto response = rds_client.DescribeDBClusterEndpoints(request);
const auto custom_endpoints = response.GetResult().GetDBClusterEndpoints();
if (custom_endpoints.size() != 1) {
MYLOG_TRACE(this->logger, 0,
"Unexpected number of custom endpoints with endpoint identifier '%s' in region '%s'. Expected 1 "
"custom endpoint, but found %d. Endpoints: %s",
endpoint_identifier.c_str(), region.c_str(), custom_endpoints.size(),
this->get_endpoints_as_string(custom_endpoints).c_str());

std::this_thread::sleep_for(std::chrono::nanoseconds(this->refresh_rate_nanos));
continue;
}
const std::shared_ptr<CUSTOM_ENDPOINT_INFO> endpoint_info =
CUSTOM_ENDPOINT_INFO::from_db_cluster_endpoint(custom_endpoints[0]);
const std::shared_ptr<CUSTOM_ENDPOINT_INFO> cache_endpoint_info =
custom_endpoint_cache.get(this->custom_endpoint_host, nullptr);

if (cache_endpoint_info != nullptr && cache_endpoint_info == endpoint_info) {
const long long elapsed_time =
std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::steady_clock::now() - start).count();
std::this_thread::sleep_for(
std::chrono::nanoseconds(std::max(static_cast<long long>(0), this->refresh_rate_nanos - elapsed_time)));
continue;
}

MYLOG_TRACE(this->logger, 0, "Detected change in custom endpoint info for '%s':\n{%s}",
custom_endpoint_host.c_str(), endpoint_info->to_string().c_str());

// The custom endpoint info has changed, so we need to update the set of allowed/blocked hosts.
std::shared_ptr<ALLOWED_AND_BLOCKED_HOSTS> allowed_and_blocked_hosts;
if (endpoint_info->get_member_list_type() == STATIC_LIST) {
allowed_and_blocked_hosts =
std::make_shared<ALLOWED_AND_BLOCKED_HOSTS>(endpoint_info->get_static_members(), std::set<std::string>());
} else {
allowed_and_blocked_hosts = std::make_shared<ALLOWED_AND_BLOCKED_HOSTS>(
std::set<std::string>(), endpoint_info->get_excluded_members());
}

this->topology_service->set_allowed_and_blocked_hosts(allowed_and_blocked_hosts);
custom_endpoint_cache.put(this->custom_endpoint_host, endpoint_info, CUSTOM_ENDPOINT_INFO_EXPIRATION_NANOS);
const long long elapsed_time =
std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::steady_clock::now() - start).count();
std::this_thread::sleep_for(
std::chrono::nanoseconds(std::max(static_cast<long long>(0), this->refresh_rate_nanos - elapsed_time)));
continue;
}

MYLOG_TRACE(this->logger, 0, "Detected change in custom endpoint info for '%s':\n{%s}",
custom_endpoint_host.c_str(), endpoint_info->to_string().c_str());

// The custom endpoint info has changed, so we need to update the set of allowed/blocked hosts.
std::shared_ptr<ALLOWED_AND_BLOCKED_HOSTS> allowed_and_blocked_hosts;
if (endpoint_info->get_member_list_type() == STATIC_LIST) {
allowed_and_blocked_hosts =
std::make_shared<ALLOWED_AND_BLOCKED_HOSTS>(endpoint_info->get_static_members(), std::set<std::string>());
} else {
allowed_and_blocked_hosts =
std::make_shared<ALLOWED_AND_BLOCKED_HOSTS>(std::set<std::string>(), endpoint_info->get_excluded_members());
}

this->topology_service->set_allowed_and_blocked_hosts(allowed_and_blocked_hosts);
custom_endpoint_cache.put(this->custom_endpoint_host, endpoint_info, CUSTOM_ENDPOINT_INFO_EXPIRATION_NANOS);
const long long elapsed_time =
std::chrono::duration_cast<std::chrono::nanoseconds>(std::chrono::steady_clock::now() - start).count();
std::this_thread::sleep_for(
std::chrono::nanoseconds(std::max(static_cast<long long>(0), this->refresh_rate_nanos - elapsed_time)));
--SDK_HELPER;
} catch (const std::exception& e) {
// Log and continue monitoring.
--SDK_HELPER;
MYLOG_TRACE(this->logger, 0, "Error while monitoring custom endpoint: %s", e.what());
}

--SDK_HELPER;
} catch (const std::exception& e) {
// Log and continue monitoring.
--SDK_HELPER;
MYLOG_TRACE(this->logger, 0, "Error while monitoring custom endpoint: %s", e.what());
}

should_stop = true;
should_stop = true;
});
}

std::string CUSTOM_ENDPOINT_MONITOR::get_endpoints_as_string(
Expand All @@ -191,9 +194,8 @@ std::string CUSTOM_ENDPOINT_MONITOR::get_endpoints_as_string(

void CUSTOM_ENDPOINT_MONITOR::stop() {
should_stop = true;
//thread_pool.stop(false);
//thread_pool.resize(0);
thread_pool.join();
thread_pool.stop(true);
thread_pool.resize(0);
custom_endpoint_cache.remove(this->custom_endpoint_host);
MYLOG_TRACE(this->logger, 0, "Stopped custom endpoint monitor for '%s'", this->custom_endpoint_host.c_str());
}
Expand Down
11 changes: 6 additions & 5 deletions driver/custom_endpoint_monitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,14 @@ class CUSTOM_ENDPOINT_MONITOR : public std::enable_shared_from_this<CUSTOM_ENDPO
public:
CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr<TOPOLOGY_SERVICE> topology_service,
const std::string& custom_endpoint_host, const std::string& endpoint_identifier,
const std::string& region, long long refresh_rate_nanos, bool enable_logging = false);
const std::string& region, long long refresh_rate_nanos, ctpl::thread_pool& thread_pool,
bool enable_logging = false);
#ifdef UNIT_TEST_BUILD
CUSTOM_ENDPOINT_MONITOR() = default;
CUSTOM_ENDPOINT_MONITOR(ctpl::thread_pool& pool): thread_pool(pool){};
CUSTOM_ENDPOINT_MONITOR(const std::shared_ptr<TOPOLOGY_SERVICE> topology_service,
const std::string& custom_endpoint_host, const std::string& endpoint_identifier,
const std::string& region, long long refresh_rate_nanos, bool enable_logging,
std::shared_ptr<Aws::RDS::RDSClient> client);
const std::string& region, long long refresh_rate_nanos, ctpl::thread_pool& thread_pool,
bool enable_logging, std::shared_ptr<Aws::RDS::RDSClient> client);
#endif

static bool should_dispose();
Expand All @@ -66,7 +67,7 @@ class CUSTOM_ENDPOINT_MONITOR : public std::enable_shared_from_this<CUSTOM_ENDPO
long long refresh_rate_nanos;
bool enable_logging;
std::shared_ptr<FILE> logger;
std::thread thread_pool;
ctpl::thread_pool& thread_pool;
bool should_stop;
std::shared_ptr<TOPOLOGY_SERVICE> topology_service;

Expand Down
3 changes: 2 additions & 1 deletion driver/custom_endpoint_proxy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ void CUSTOM_ENDPOINT_PROXY::wait_for_custom_endpoint_info(std::shared_ptr<CUSTOM
std::shared_ptr<CUSTOM_ENDPOINT_MONITOR> CUSTOM_ENDPOINT_PROXY::create_custom_endpoint_monitor(
const long long refresh_rate_nanos) {
return std::make_shared<CUSTOM_ENDPOINT_MONITOR>(this->topology_service, this->custom_endpoint_host,
this->custom_endpoint_id, this->region, refresh_rate_nanos);
this->custom_endpoint_id, this->region, refresh_rate_nanos,
this->dbc->env->custom_endpoint_thread_pool);
}

std::shared_ptr<CUSTOM_ENDPOINT_MONITOR> CUSTOM_ENDPOINT_PROXY::create_monitor_if_absent(DataSource* ds) {
Expand Down
1 change: 1 addition & 0 deletions driver/driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,7 @@ struct ENV
MYERROR error;
std::mutex lock;
ctpl::thread_pool failover_thread_pool;
ctpl::thread_pool custom_endpoint_thread_pool;

ENV(SQLINTEGER ver) : odbc_ver(ver)
{}
Expand Down
4 changes: 2 additions & 2 deletions driver/sliding_expiration_cache_with_clean_up_thread.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ class SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD : public SLIDING_EXPIRATION_
SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD() = default;
SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD(std::shared_ptr<SHOULD_DISPOSE_FUNC<V>> should_dispose_func,
std::shared_ptr<ITEM_DISPOSAL_FUNC<V>> item_disposal_func)
: SLIDING_EXPIRATION_CACHE(std::move(should_dispose_func), std::move(item_disposal_func)){};
: SLIDING_EXPIRATION_CACHE<K, V>(std::move(should_dispose_func), std::move(item_disposal_func)){};
SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD(std::shared_ptr<SHOULD_DISPOSE_FUNC<V>> should_dispose_func,
std::shared_ptr<ITEM_DISPOSAL_FUNC<V>> item_disposal_func,
long long clean_up_interval_nanos)
: SLIDING_EXPIRATION_CACHE(std::move(should_dispose_func), std::move(item_disposal_func),
: SLIDING_EXPIRATION_CACHE<K, V>(std::move(should_dispose_func), std::move(item_disposal_func),
clean_up_interval_nanos){};
~SLIDING_EXPIRATION_CACHE_WITH_CLEAN_UP_THREAD() = default;

Expand Down
7 changes: 1 addition & 6 deletions integration/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,9 @@ set(TEST_SOURCES
integration_test_utils.cc
integration_test_utils.h

base_failover_integration_test.cc
connection_string_builder_test.cc)
base_failover_integration_test.cc)
set(INTEGRATION_TESTS
custom_endpoint_integration_test.cc
iam_authentication_integration_test.cc
secrets_manager_integration_test.cc
network_failover_integration_test.cc
failover_integration_test.cc
)

if(NOT ENABLE_PERFORMANCE_TESTS)
Expand Down
14 changes: 14 additions & 0 deletions integration/custom_endpoint_integration_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,21 +151,35 @@ TEST_F(CustomEndpointIntegrationTest, test_CustomeEndpointFailover) {
.getString();
SQLCHAR conn_out[4096] = "\0";
SQLSMALLINT len;
std::cout << "Establishing connection to " << connection_string.c_str() << std::endl;
EXPECT_EQ(SQL_SUCCESS, SQLDriverConnect(dbc, nullptr, AS_SQLCHAR(connection_string.c_str()), SQL_NTS, conn_out,
MAX_NAME_LEN, &len, SQL_DRIVER_NOPROMPT));

std::cout << "Established connection to " << connection_string.c_str() << std::endl;
const std::vector<std::string>& endpoint_members = endpoint_info.GetStaticMembers();
const std::string current_connection_id = query_instance_id(dbc);

EXPECT_NE(std::find(endpoint_members.begin(), endpoint_members.end(), current_connection_id), endpoint_members.end());

std::cout << "Triggering failover from " << writer_id.c_str() << " to " << current_connection_id.c_str() << std::endl;
failover_cluster_and_wait_until_writer_changed(
rds_client, cluster_id, writer_id, current_connection_id == writer_id ? target_writer_id : current_connection_id);

assert_query_failed(dbc, SERVER_ID_QUERY, ERROR_COMM_LINK_CHANGED);

std::cout << "Failover triggered" << std::endl;

const std::string new_connection_id = query_instance_id(dbc);

std::cout << "New connection ID " << new_connection_id.c_str() << std::endl;

EXPECT_NE(std::find(endpoint_members.begin(), endpoint_members.end(), new_connection_id), endpoint_members.end());
EXPECT_EQ(SQL_SUCCESS, SQLDisconnect(dbc));

if (nullptr != dbc) {
SQLFreeHandle(SQL_HANDLE_DBC, dbc);
}
if (nullptr != env) {
SQLFreeHandle(SQL_HANDLE_ENV, env);
}
}

0 comments on commit 7642d33

Please sign in to comment.