Skip to content

Commit

Permalink
feat: dialect selection (#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
crystall-bitquill authored May 16, 2024
1 parent 9abcfbe commit 0e1f847
Show file tree
Hide file tree
Showing 25 changed files with 723 additions and 67 deletions.
24 changes: 13 additions & 11 deletions common/lib/aws_client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import { PluginService } from "./plugin_service";
import { HostInfo } from "./host_info";
import { WrapperProperties } from "./wrapper_property";
import { ErrorHandler } from "./error_handler";
import { DatabaseDialect } from "./database_dialect";
import { DatabaseDialect, DatabaseType } from "./database_dialect/database_dialect";
import { ConnectionUrlParser } from "./utils/connection_url_parser";
import { HostListProvider } from "./host_list_provider/host_list_provider";
import { PluginManager } from "./plugin_manager";
Expand All @@ -36,25 +36,29 @@ export abstract class AwsClient extends EventEmitter {
private readonly _properties: Map<string, any>;
private _targetClient: any;
protected _errorHandler: ErrorHandler;
protected _dialect: DatabaseDialect;
protected _createClientFunc?: (config: any) => any;
protected _connectFunc?: () => Promise<any>;
protected _connectionUrlParser: ConnectionUrlParser;

protected constructor(config: any, errorHandler: ErrorHandler, dialect: DatabaseDialect, parser: ConnectionUrlParser) {
protected constructor(
config: any,
errorHandler: ErrorHandler,
dbType: DatabaseType,
knownDialectsByCode: Map<string, DatabaseDialect>,
parser: ConnectionUrlParser
) {
super();
this._errorHandler = errorHandler;
this._connectionUrlParser = parser;

this._dialect = dialect;
this._properties = new Map<string, any>(Object.entries(config));

const defaultConnProvider = new DriverConnectionProvider();
const effectiveConnProvider = null;
// TODO: check for configuration profile to update the effectiveConnProvider

const container = new PluginServiceManagerContainer();
this.pluginService = new PluginService(container, this, this._properties);
this.pluginService = new PluginService(container, this, dbType, knownDialectsByCode, this.properties);
this.pluginManager = new PluginManager(container, this._properties, defaultConnProvider, effectiveConnProvider);

// TODO: properly set up host info
Expand All @@ -65,7 +69,9 @@ export abstract class AwsClient extends EventEmitter {
}

protected async internalConnect() {
const hostListProvider: HostListProvider = this.dialect.getHostListProvider(this._properties, this._properties.get("host"), this.pluginService);
const hostListProvider: HostListProvider = this.pluginService
.getDialect()
.getHostListProvider(this._properties, this._properties.get("host"), this.pluginService);
this.pluginService.setHostListProvider(hostListProvider);
const info = this.pluginService.getCurrentHostInfo();
if (info != null) {
Expand Down Expand Up @@ -109,10 +115,6 @@ export abstract class AwsClient extends EventEmitter {
return this._errorHandler;
}

get dialect(): DatabaseDialect {
return this._dialect;
}

get connectionUrlParser(): ConnectionUrlParser {
return this._connectionUrlParser;
}
Expand All @@ -135,6 +137,6 @@ export abstract class AwsClient extends EventEmitter {
if (!this.targetClient) {
return Promise.resolve(false);
}
return await this.dialect.isClientValid(this.targetClient);
return await this.pluginService.getDialect().isClientValid(this.targetClient);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,26 @@
limitations under the License.
*/

import { AwsClient } from "./aws_client";
import { HostListProvider } from "./host_list_provider/host_list_provider";
import { HostListProviderService } from "./host_list_provider_service";
import { AwsClient } from "../aws_client";
import { HostListProvider } from "../host_list_provider/host_list_provider";
import { HostListProviderService } from "../host_list_provider_service";

export enum DatabaseType {
MYSQL,
POSTGRES
}

export interface DatabaseDialect {
getConnectFunc(newTargetClient: AwsClient): () => Promise<any>;
tryClosingTargetClient(targetClient: any): Promise<void>;
isClientValid(targetClient: any): Promise<boolean>;
getDefaultPort(): number;
getHostAliasQuery(): string;
getHostAliasAndParseResults(client: AwsClient): Promise<string>;
getServerVersionQuery(): string;
getDialectUpdateCandidates(): string[];
isDialect<T>(conn: T): boolean;
isDialect<T>(targetClient: T): Promise<boolean>;
getHostListProvider(props: Map<string, any>, originalUrl: string, hostListProviderService: HostListProviderService): HostListProvider;
tryClosingTargetClient(targetClient: any): Promise<void>;
isClientValid(targetClient: any): Promise<boolean>;
getConnectFunc(targetClient: any): () => Promise<any>;
getDatabaseType(): DatabaseType;
getDialectName(): string;
}
25 changes: 25 additions & 0 deletions common/lib/database_dialect/database_dialect_codes.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License").
You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

export class DatabaseDialectCodes {
static readonly AURORA_MYSQL: string = "aurora-mysql";
static readonly RDS_MYSQL: string = "rds-mysql";
static readonly MYSQL: string = "mysql";
static readonly AURORA_PG: string = "aurora-pg";
static readonly RDS_PG: string = "rds-pg";
static readonly PG: string = "pg";
static readonly CUSTOM: string = "custom";
}
180 changes: 180 additions & 0 deletions common/lib/database_dialect/database_dialect_manager.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
/*
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License").
You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

import { DatabaseDialectProvider } from "./database_dialect_provider";
import { DatabaseDialect, DatabaseType } from "./database_dialect";
import { DatabaseDialectCodes } from "./database_dialect_codes";
import { WrapperProperties } from "../wrapper_property";
import { AwsWrapperError } from "../utils/errors";
import { Messages } from "../utils/messages";
import { RdsUtils } from "../utils/rds_utils";
import { logger } from "../../logutils";
import { CacheMap } from "../utils/cache_map";

export class DatabaseDialectManager implements DatabaseDialectProvider {
/**
* In order to simplify dialect detection, there's an internal host-to-dialect cache.
* The cache contains host endpoints and identified dialect. Cache expiration time in
* milliseconds is defined by the variable below.
*/
private static readonly ENDPOINT_CACHE_EXPIRATION_MS = 86_400_000_000_000; // 24 hours
protected static readonly knownEndpointDialects: CacheMap<string, string> = new CacheMap();
protected readonly knownDialectsByCode: Map<string, DatabaseDialect>;

private static customDialect: DatabaseDialect | null = null;
private readonly rdsHelper: RdsUtils = new RdsUtils();
private readonly dbType;
private canUpdate: boolean = false;
private dialect: DatabaseDialect;
private dialectCode: string = "";

constructor(knownDialectsByCode: any, dbType: DatabaseType, props: Map<string, any>) {
this.knownDialectsByCode = knownDialectsByCode;
this.dbType = dbType;
this.dialect = this.getDialect(props);
}

static setCustomDialect(dialect: DatabaseDialect) {
DatabaseDialectManager.customDialect = dialect;
}

static resetCustomDialect() {
DatabaseDialectManager.customDialect = null;
}

static resetEndpointCache() {
DatabaseDialectManager.knownEndpointDialects.clear();
}

getDialect(props: Map<string, any>): DatabaseDialect {
if (this.dialect) {
return this.dialect;
}

this.canUpdate = false;

if (DatabaseDialectManager.customDialect) {
this.dialectCode = DatabaseDialectCodes.CUSTOM;
this.dialect = DatabaseDialectManager.customDialect;
this.logCurrentDialect();
return this.dialect;
}

const userDialectSetting = WrapperProperties.DIALECT.get(props);
const host = props.get(WrapperProperties.HOST.name);
const dialectCode = userDialectSetting ?? DatabaseDialectManager.knownEndpointDialects.get(host);

if (dialectCode) {
const userDialect = this.knownDialectsByCode.get(dialectCode);
if (userDialect) {
this.dialectCode = dialectCode;
this.dialect = userDialect;
this.logCurrentDialect();
return userDialect;
}
throw new AwsWrapperError(Messages.get("DialectManager.unknownDialectCode", dialectCode));
}

if (this.dbType === DatabaseType.MYSQL) {
const type = this.rdsHelper.identifyRdsType(host);
if (type.isRdsCluster) {
this.dialectCode = DatabaseDialectCodes.AURORA_MYSQL;
this.dialect = <DatabaseDialect>this.knownDialectsByCode.get(DatabaseDialectCodes.AURORA_MYSQL);
this.logCurrentDialect();
return this.dialect;
}

if (type.isRds) {
this.canUpdate = true;
this.dialectCode = DatabaseDialectCodes.RDS_MYSQL;
this.dialect = <DatabaseDialect>this.knownDialectsByCode.get(DatabaseDialectCodes.RDS_MYSQL);
this.logCurrentDialect();
return this.dialect;
}

this.canUpdate = true;
this.dialectCode = DatabaseDialectCodes.MYSQL;
this.dialect = <DatabaseDialect>this.knownDialectsByCode.get(DatabaseDialectCodes.MYSQL);
this.logCurrentDialect();
return this.dialect;
}

if (this.dbType === DatabaseType.POSTGRES) {
const type = this.rdsHelper.identifyRdsType(host);
if (type.isRdsCluster) {
this.dialectCode = DatabaseDialectCodes.AURORA_PG;
this.dialect = <DatabaseDialect>this.knownDialectsByCode.get(DatabaseDialectCodes.AURORA_PG);
this.logCurrentDialect();
return this.dialect;
}

if (type.isRds) {
this.canUpdate = true;
this.dialectCode = DatabaseDialectCodes.RDS_PG;
this.dialect = <DatabaseDialect>this.knownDialectsByCode.get(DatabaseDialectCodes.RDS_PG);
this.logCurrentDialect();
return this.dialect;
}

this.canUpdate = true;
this.dialectCode = DatabaseDialectCodes.PG;
this.dialect = <DatabaseDialect>this.knownDialectsByCode.get(DatabaseDialectCodes.PG);
this.logCurrentDialect();
return this.dialect;
}

throw new AwsWrapperError(Messages.get("DialectManager.getDialectError"));
}

async getDialectForUpdate(targetClient: any, originalHost: string, newHost: string): Promise<DatabaseDialect> {
if (!this.canUpdate) {
return this.dialect;
}

const dialectCandidates = this.dialect.getDialectUpdateCandidates();
if (dialectCandidates.length > 0) {
for (const dialectCandidateCode of dialectCandidates) {
const dialectCandidate = this.knownDialectsByCode.get(dialectCandidateCode);
if (!dialectCandidate) {
throw new AwsWrapperError(Messages.get("DialectManager.unknownDialectCode", dialectCandidateCode));
}

const isDialect = await dialectCandidate.isDialect(targetClient);
if (isDialect) {
this.canUpdate = false;
this.dialectCode = dialectCandidateCode;
this.dialect = dialectCandidate;

DatabaseDialectManager.knownEndpointDialects.put(originalHost, dialectCandidateCode, DatabaseDialectManager.ENDPOINT_CACHE_EXPIRATION_MS);
DatabaseDialectManager.knownEndpointDialects.put(newHost, dialectCandidateCode, DatabaseDialectManager.ENDPOINT_CACHE_EXPIRATION_MS);

this.logCurrentDialect();
return this.dialect;
}
}
}

DatabaseDialectManager.knownEndpointDialects.put(originalHost, this.dialectCode, DatabaseDialectManager.ENDPOINT_CACHE_EXPIRATION_MS);
DatabaseDialectManager.knownEndpointDialects.put(newHost, this.dialectCode, DatabaseDialectManager.ENDPOINT_CACHE_EXPIRATION_MS);

this.logCurrentDialect();
return this.dialect;
}

logCurrentDialect() {
logger.info(`Current dialect: ${this.dialectCode}, ${this.dialect.getDialectName()}, canUpdate: ${this.canUpdate}`);
}
}
22 changes: 22 additions & 0 deletions common/lib/database_dialect/database_dialect_provider.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License").
You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

import { DatabaseDialect, DatabaseType } from "./database_dialect";

export interface DatabaseDialectProvider {
getDialect(props: Map<string, any>): DatabaseDialect;
getDialectForUpdate(targetClient: any, originalHost: string, newHost: string): Promise<DatabaseDialect>;
}
5 changes: 3 additions & 2 deletions common/lib/host_list_provider/host_list_provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import { AwsClient } from "../aws_client";
import { HostInfo } from "../host_info";
import { HostRole } from "../host_role";
import { DatabaseDialect } from "../database_dialect/database_dialect";

export interface DynamicHostListProvider extends HostListProvider {}

Expand All @@ -31,9 +32,9 @@ export interface HostListProvider {

forceRefresh(client: AwsClient): Promise<HostInfo[]>;

getHostRole(client: AwsClient): Promise<HostRole>;
getHostRole(client: AwsClient, dialect: DatabaseDialect): Promise<HostRole>;

identifyConnection(client: AwsClient): Promise<HostInfo | void | null>;
identifyConnection(client: AwsClient, dialect: DatabaseDialect): Promise<HostInfo | void | null>;

createHost(host: string, isWriter: boolean, weight: number, lastUpdateTime: number): HostInfo;
}
14 changes: 7 additions & 7 deletions common/lib/host_list_provider/rds_host_list_provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import { HostAvailability } from "../host_availability/host_availability";
import { CacheMap } from "../utils/cache_map";
import { logTopology } from "../utils/utils";
import { TopologyAwareDatabaseDialect } from "../topology_aware_database_dialect";
import { DatabaseDialect } from "../database_dialect";
import { DatabaseDialect } from "../database_dialect/database_dialect";

export class RdsHostListProvider implements DynamicHostListProvider {
private readonly hostListProviderService: HostListProviderService;
Expand Down Expand Up @@ -130,19 +130,19 @@ export class RdsHostListProvider implements DynamicHostListProvider {
throw new AwsWrapperError("Could not retrieve targetClient.");
}

async getHostRole(client: AwsClient): Promise<HostRole> {
if (!this.isTopologyAwareDatabaseDialect(client.dialect)) {
async getHostRole(client: AwsClient, dialect: DatabaseDialect): Promise<HostRole> {
if (!this.isTopologyAwareDatabaseDialect(dialect)) {
throw new TypeError(Messages.get("RdsHostListProvider.incorrectDialect"));
}

return client.dialect.getHostRole(client, this.properties);
return dialect.getHostRole(client, this.properties);
}

async identifyConnection(client: AwsClient): Promise<HostInfo | null> {
if (!this.isTopologyAwareDatabaseDialect(client.dialect)) {
async identifyConnection(client: AwsClient, dialect: DatabaseDialect): Promise<HostInfo | null> {
if (!this.isTopologyAwareDatabaseDialect(dialect)) {
throw new TypeError(Messages.get("RdsHostListProvider.incorrectDialect"));
}
const instanceName = await client.dialect.identifyConnection(client, this.properties);
const instanceName = await dialect.identifyConnection(client, this.properties);

return this.refresh().then((topology) => {
const matches = topology.filter((host) => host.hostId === instanceName);
Expand Down
Loading

0 comments on commit 0e1f847

Please sign in to comment.