Skip to content

Commit

Permalink
chore: fix unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
karenc-bq committed Oct 22, 2024
1 parent f803204 commit f44af69
Show file tree
Hide file tree
Showing 8 changed files with 29 additions and 30 deletions.
1 change: 0 additions & 1 deletion common/lib/database_dialect/database_dialect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ export interface DatabaseDialect {
getServerVersionQuery(): string;
getDialectUpdateCandidates(): string[];
isDialect(targetClient: ClientWrapper): Promise<boolean>;
getAwsPoolClient(props: any): AwsPoolClient;
getHostListProvider(props: Map<string, any>, originalUrl: string, hostListProviderService: HostListProviderService): HostListProvider;
isClientValid(targetClient: ClientWrapper): Promise<boolean>;
getDatabaseType(): DatabaseType;
Expand Down
2 changes: 2 additions & 0 deletions common/lib/driver_dialect/driver_dialect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import { ClientWrapper } from "../client_wrapper";
import { AwsPoolConfig } from "../aws_pool_config";
import { AwsPoolClient } from "../aws_pool_client";

export interface DriverDialect {
getDialectName(): string;
Expand All @@ -24,4 +25,5 @@ export interface DriverDialect {
connect(targetClient: any): Promise<any>;
end(targetClient: ClientWrapper | undefined): Promise<void>;
preparePoolClientProperties(props: Map<string, any>, poolConfig: AwsPoolConfig | undefined): any;
getAwsPoolClient(props: any): AwsPoolClient;
}
2 changes: 1 addition & 1 deletion common/lib/internal_pooled_connection_provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ export class InternalPooledConnectionProvider implements PooledConnectionProvide
}
}

const dialect = pluginService.getDialect();
const dialect = pluginService.getDriverDialect();
const preparedConfig = dialect.preparePoolClientProperties(props, this._poolConfig);

this.internalPool = this.databasePools.computeIfAbsent(
Expand Down
6 changes: 6 additions & 0 deletions mysql/lib/dialect/mysql2_driver_dialect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import { ClientUtils } from "../../../common/lib/utils/client_utils";
import { createConnection, PoolOptions } from "mysql2/promise";
import { WrapperProperties } from "../../../common/lib/wrapper_property";
import { AwsPoolConfig } from "../../../common/lib/aws_pool_config";
import { AwsPoolClient } from "../../../common/lib/aws_pool_client";
import { AwsMysqlPoolClient } from "../mysql_pool_client";

export class MySQL2DriverDialect implements DriverDialect {
protected dialectName: string = this.constructor.name;
Expand Down Expand Up @@ -57,4 +59,8 @@ export class MySQL2DriverDialect implements DriverDialect {
finalPoolConfig.idleTimeout = poolConfig?.idleTimeoutMillis;
return finalPoolConfig;
}

getAwsPoolClient(props: PoolOptions): AwsPoolClient {
return new AwsMysqlPoolClient(props);
}
}
13 changes: 3 additions & 10 deletions mysql/lib/dialect/mysql_database_dialect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,15 @@
import { DatabaseDialect, DatabaseType } from "../../../common/lib/database_dialect/database_dialect";
import { HostListProviderService } from "../../../common/lib/host_list_provider_service";
import { HostListProvider } from "../../../common/lib/host_list_provider/host_list_provider";
import { ConnectionStringHostListProvider } from "../../../common/lib/host_list_provider/connection_string_host_list_provider";
import {
ConnectionStringHostListProvider
} from "../../../common/lib/host_list_provider/connection_string_host_list_provider";
import { AwsWrapperError } from "../../../common/lib/utils/errors";
import { DatabaseDialectCodes } from "../../../common/lib/database_dialect/database_dialect_codes";
import { TransactionIsolationLevel } from "../../../common/lib/utils/transaction_isolation_level";
import { ClientWrapper } from "../../../common/lib/client_wrapper";
import { ClientUtils } from "../../../common/lib/utils/client_utils";
import { FailoverRestriction } from "../../../common/lib/plugins/failover/failover_restriction";
import { AwsPoolClient } from "../../../common/lib/aws_pool_client";
import { AwsMysqlPoolClient } from "../mysql_pool_client";
import { AwsPoolConfig } from "../../../common/lib/aws_pool_config";
import { WrapperProperties } from "../../../common/lib/wrapper_property";
import { PoolOptions } from "mysql2/promise";

export class MySQLDatabaseDialect implements DatabaseDialect {
protected dialectName: string = this.constructor.name;
Expand Down Expand Up @@ -73,10 +70,6 @@ export class MySQLDatabaseDialect implements DatabaseDialect {
});
}

getAwsPoolClient(props: PoolOptions): AwsPoolClient {
return new AwsMysqlPoolClient(props);
}

getHostListProvider(props: Map<string, any>, originalUrl: string, hostListProviderService: HostListProviderService): HostListProvider {
return new ConnectionStringHostListProvider(props, originalUrl, this.getDefaultPort(), hostListProviderService);
}
Expand Down
6 changes: 6 additions & 0 deletions pg/lib/dialect/node_postgres_driver_dialect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import { ClientWrapper } from "../../../common/lib/client_wrapper";
import { Client, PoolConfig } from "pg";
import { WrapperProperties } from "../../../common/lib/wrapper_property";
import { AwsPoolConfig } from "../../../common/lib/aws_pool_config";
import { AwsPoolClient } from "../../../common/lib/aws_pool_client";
import { AwsPgPoolClient } from "../pg_pool_client";

export class NodePostgresDriverDialect implements DriverDialect {
protected dialectName: string = this.constructor.name;
Expand Down Expand Up @@ -58,4 +60,8 @@ export class NodePostgresDriverDialect implements DriverDialect {
finalPoolConfig.allowExitOnIdle = poolConfig?.allowExitOnIdle;
return finalPoolConfig;
}

getAwsPoolClient(props: PoolConfig): AwsPoolClient {
return new AwsPgPoolClient(props);
}
}
14 changes: 3 additions & 11 deletions pg/lib/dialect/pg_database_dialect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,14 @@
import { DatabaseDialect, DatabaseType } from "../../../common/lib/database_dialect/database_dialect";
import { HostListProviderService } from "../../../common/lib/host_list_provider_service";
import { HostListProvider } from "../../../common/lib/host_list_provider/host_list_provider";
import { ConnectionStringHostListProvider } from "../../../common/lib/host_list_provider/connection_string_host_list_provider";
import {
ConnectionStringHostListProvider
} from "../../../common/lib/host_list_provider/connection_string_host_list_provider";
import { AwsWrapperError } from "../../../common/lib/utils/errors";
import { DatabaseDialectCodes } from "../../../common/lib/database_dialect/database_dialect_codes";
import { TransactionIsolationLevel } from "../../../common/lib/utils/transaction_isolation_level";
import { ClientWrapper } from "../../../common/lib/client_wrapper";
import { FailoverRestriction } from "../../../common/lib/plugins/failover/failover_restriction";
import { AwsPoolClient } from "../../../common/lib/aws_pool_client";
import { AwsMysqlPoolClient } from "../../../mysql/lib/mysql_pool_client";
import { AwsPgPoolClient } from "../pg_pool_client";
import { AwsPoolConfig } from "../../../common/lib/aws_pool_config";
import { PoolClient, PoolConfig } from "pg";
import { WrapperProperties } from "../../../common/lib/wrapper_property";

export class PgDatabaseDialect implements DatabaseDialect {
protected dialectName: string = this.constructor.name;
Expand Down Expand Up @@ -72,10 +68,6 @@ export class PgDatabaseDialect implements DatabaseDialect {
});
}

getAwsPoolClient(props: PoolConfig): AwsPoolClient {
return new AwsPgPoolClient(props);
}

getHostListProvider(props: Map<string, any>, originalUrl: string, hostListProviderService: HostListProviderService): HostListProvider {
return new ConnectionStringHostListProvider(props, originalUrl, this.getDefaultPort(), hostListProviderService);
}
Expand Down
15 changes: 8 additions & 7 deletions tests/unit/internal_pool_connection_provider.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import { AwsMysqlPoolClient } from "../../mysql/lib/mysql_pool_client";
import { PoolKey } from "../../common/lib/utils/pool_key";
import { InternalPoolMapping } from "../../common/lib/utils/internal_pool_mapping";
import { SlidingExpirationCache } from "../../common/lib/utils/sliding_expiration_cache";
import { MySQL2DriverDialect } from "../../mysql/lib/dialect/mysql2_driver_dialect";

const internalPoolWithOneConnection = mock(AwsMysqlPoolClient);
const user1 = "user1";
Expand Down Expand Up @@ -67,6 +68,7 @@ const mockHostListProvider: HostListProvider = mock<HostListProvider>();
const mockClosedReaderClient: AwsClient = mock(AwsMySQLClient);
const mockClosedWriterClient: AwsClient = mock(AwsMySQLClient);
const mockDialect: MySQLDatabaseDialect = mock(MySQLDatabaseDialect);
const mockDriverDialect: MySQL2DriverDialect = mock(MySQL2DriverDialect);
const mockPoolConnection = mock(AwsMysqlPoolClient);
const mockAwsPoolClient = mock(AwsMysqlPoolClient);
const mockRdsUtils = mock(RdsUtils);
Expand All @@ -78,6 +80,9 @@ describe("reader write splitting test", () => {
when(mockPluginService.getHostListProvider()).thenReturn(instance(mockHostListProvider));
when(mockPluginService.getHosts()).thenReturn(defaultHosts);
when(mockPluginService.isInTransaction()).thenReturn(false);
when(mockPluginService.getDialect()).thenReturn(mockDialect);
when(mockPluginService.getDriverDialect()).thenReturn(mockDriverDialect);
when(mockDriverDialect.getAwsPoolClient(anything())).thenReturn(mockAwsPoolClient);
props.clear();
});

Expand Down Expand Up @@ -106,14 +111,12 @@ describe("reader write splitting test", () => {
when(mockRdsUtils.isRdsDns(anything())).thenReturn(null);
when(mockRdsUtils.isGreenInstance(anything())).thenReturn(null);
when(mockRdsUtils.isRdsInstance("instance1")).thenReturn(true);
when(mockPluginService.getDialect()).thenReturn(mockDialect);
when(mockDialect.getAwsPoolClient(anything())).thenReturn(mockAwsPoolClient);
const config = {
maxConnection: 10,
idleTimeoutMillis: 10000,
connectionTimeoutMillis: 10000
};
when(mockDialect.preparePoolClientProperties(anything(), anything())).thenReturn(config);
when(mockDriverDialect.preparePoolClientProperties(anything(), anything())).thenReturn(config);
const poolConfig: AwsPoolConfig = new AwsPoolConfig(config);

const provider = spy(new InternalPooledConnectionProvider(poolConfig));
Expand Down Expand Up @@ -146,7 +149,6 @@ describe("reader write splitting test", () => {
when(mockRdsUtils.isGreenInstance(anything())).thenReturn(null);
when(mockRdsUtils.isRdsInstance("instance1")).thenReturn(true);
when(mockPluginService.getDialect()).thenReturn(mockDialect);
when(mockDialect.getAwsPoolClient(anything())).thenReturn(mockAwsPoolClient);
const config = {
maxConnection: 10,
idleTimeoutMillis: 10000,
Expand All @@ -157,7 +159,7 @@ describe("reader write splitting test", () => {
return hostInfo.url + "someKey";
}
};
when(mockDialect.preparePoolClientProperties(anything(), anything())).thenReturn(config);
when(mockDriverDialect.preparePoolClientProperties(anything(), anything())).thenReturn(config);
const poolConfig: AwsPoolConfig = new AwsPoolConfig(config);

const provider = spy(new InternalPooledConnectionProvider(poolConfig, myKeyFunc));
Expand Down Expand Up @@ -197,8 +199,7 @@ describe("reader write splitting test", () => {
when(mockRdsUtils.isRdsDns(anything())).thenReturn(null);
when(mockRdsUtils.isGreenInstance(anything())).thenReturn(null);
when(mockRdsUtils.isRdsInstance("instance1")).thenReturn(true);
when(mockDialect.preparePoolClientProperties(anything(), anything())).thenReturn(props);
when(mockDialect.getAwsPoolClient(anything())).thenThrow(new Error("testError"));
when(mockDriverDialect.preparePoolClientProperties(anything(), anything())).thenReturn(props);

const provider = spy(new InternalPooledConnectionProvider(poolConfig));
const providerSpy = instance(provider);
Expand Down

0 comments on commit f44af69

Please sign in to comment.