diff --git a/mysql/lib/dialect/mysql2_driver_dialect.ts b/mysql/lib/dialect/mysql2_driver_dialect.ts index 7f019a73..65f7a18d 100644 --- a/mysql/lib/dialect/mysql2_driver_dialect.ts +++ b/mysql/lib/dialect/mysql2_driver_dialect.ts @@ -76,6 +76,10 @@ export class MySQL2DriverDialect implements DriverDialect { } setKeepAliveProperties(props: Map, keepAliveProps: any) { + if (keepAliveProps instanceof Map) { + keepAliveProps = Object.fromEntries(keepAliveProps); + } + if (keepAliveProps && keepAliveProps[MySQL2DriverDialect.KEEP_ALIVE_PROPERTY_NAME] !== undefined) { throw new UnsupportedMethodError("Keep alive configuration is not supported for MySQL2."); } diff --git a/pg/lib/dialect/node_postgres_driver_dialect.ts b/pg/lib/dialect/node_postgres_driver_dialect.ts index 47b643df..c0f678db 100644 --- a/pg/lib/dialect/node_postgres_driver_dialect.ts +++ b/pg/lib/dialect/node_postgres_driver_dialect.ts @@ -83,6 +83,10 @@ export class NodePostgresDriverDialect implements DriverDialect { return; } + if (keepAliveProps instanceof Map) { + keepAliveProps = Object.fromEntries(keepAliveProps); + } + const keepAlive = keepAliveProps[NodePostgresDriverDialect.KEEP_ALIVE_PROPERTY_NAME]; const keepAliveInitialDelayMillis = keepAliveProps[NodePostgresDriverDialect.KEEP_ALIVE_INITIAL_DELAY_MILLIS_PROPERTY_NAME]; diff --git a/tests/unit/aurora_connection_tracker.test.ts b/tests/unit/aurora_connection_tracker.test.ts index a14d6cf1..03cd313f 100644 --- a/tests/unit/aurora_connection_tracker.test.ts +++ b/tests/unit/aurora_connection_tracker.test.ts @@ -29,6 +29,8 @@ import { ClientWrapper } from "../../common/lib/client_wrapper"; import { HostInfo } from "../../common/lib/host_info"; import { MySQLClientWrapper } from "../../common/lib/mysql_client_wrapper"; import { jest } from "@jest/globals"; +import { DriverDialect } from "../../common/lib/driver_dialect/driver_dialect"; +import { MySQL2DriverDialect } from "../../mysql/lib/dialect/mysql2_driver_dialect"; const props = new Map(); const SQL_ARGS = ["sql"]; @@ -48,8 +50,9 @@ const mockRdsUtils = mock(RdsUtils); const mockClient = mock(AwsClient); const mockHostInfo = mock(HostInfo); +const mockDriverDialect: DriverDialect = mock(MySQL2DriverDialect); const mockClientInstance = instance(mockClient); -const mockClientWrapper: ClientWrapper = new MySQLClientWrapper(undefined, mockHostInfo, props); +const mockClientWrapper: ClientWrapper = new MySQLClientWrapper(undefined, mockHostInfo, props, mockDriverDialect); mockClientInstance.targetClient = mockClientWrapper; diff --git a/tests/unit/aurora_initial_connection_strategy_plugin.test.ts b/tests/unit/aurora_initial_connection_strategy_plugin.test.ts index bc9457c2..f07a1341 100644 --- a/tests/unit/aurora_initial_connection_strategy_plugin.test.ts +++ b/tests/unit/aurora_initial_connection_strategy_plugin.test.ts @@ -30,6 +30,8 @@ import { AwsWrapperError } from "../../common/lib/utils/errors"; import { MySQLClientWrapper } from "../../common/lib/mysql_client_wrapper"; import { jest } from "@jest/globals"; import { PgClientWrapper } from "../../common/lib/pg_client_wrapper"; +import { DriverDialect } from "../../common/lib/driver_dialect/driver_dialect"; +import { MySQL2DriverDialect } from "../../mysql/lib/dialect/mysql2_driver_dialect"; const mockPluginService = mock(PluginService); const mockHostListProviderService = mock(); @@ -48,6 +50,7 @@ const hostInfo = hostInfoBuilder.withHost("host").build(); const writerHostInfo = hostInfoBuilder.withHost("host").withRole(HostRole.WRITER).build(); const readerHostInfo = hostInfoBuilder.withHost("host").withHost(HostRole.READER).build(); +const mockDriverDialect: DriverDialect = mock(MySQL2DriverDialect); describe("Aurora initial connection strategy plugin", () => { let props: Map; @@ -62,8 +65,8 @@ describe("Aurora initial connection strategy plugin", () => { plugin.initHostProvider(hostInfo, props, instance(mockHostListProviderService), mockFunc); WrapperProperties.OPEN_CONNECTION_RETRY_TIMEOUT_MS.set(props, 1000); - writerClient = new MySQLClientWrapper(undefined, writerHostInfo, new Map()); - readerClient = new MySQLClientWrapper(undefined, readerHostInfo, new Map()); + writerClient = new MySQLClientWrapper(undefined, writerHostInfo, new Map(), mockDriverDialect); + readerClient = new MySQLClientWrapper(undefined, readerHostInfo, new Map(), mockDriverDialect); }); afterEach(() => { diff --git a/tests/unit/database_dialect.test.ts b/tests/unit/database_dialect.test.ts index 17084135..4b631dc0 100644 --- a/tests/unit/database_dialect.test.ts +++ b/tests/unit/database_dialect.test.ts @@ -278,8 +278,7 @@ describe("test database dialects", () => { databaseType, expectedDialect!.dialects, props, - mockDriverDialect, - null + mockDriverDialect ); await pluginService.updateDialect(mockClientWrapper); expect(pluginService.getDialect()).toBe(expectedDialectClass); diff --git a/tests/unit/driver_dialect.test.ts b/tests/unit/driver_dialect.test.ts new file mode 100644 index 00000000..7b4680ae --- /dev/null +++ b/tests/unit/driver_dialect.test.ts @@ -0,0 +1,73 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { mock } from "ts-mockito"; +import { MySQL2DriverDialect } from "../../mysql/lib/dialect/mysql2_driver_dialect"; +import { HostInfo } from "../../common/lib/host_info"; +import { UnsupportedMethodError } from "../../common/lib/utils/errors"; +import { NodePostgresDriverDialect } from "../../pg/lib/dialect/node_postgres_driver_dialect"; + +const mockHostInfo: HostInfo = mock(HostInfo); +const emptyProps: Map = new Map(); + +describe("driverDialectTest", () => { + it("test_connectWithKeepAliveProps_MySQL_shouldThrow", async () => { + const keepAliveProps = new Map([ + ["keepAlive", true], + ["keepAliveInitialDelayMillis", 1234] + ]); + + const props = new Map([["wrapperKeepAliveProperties", keepAliveProps]]); + + const dialect = new MySQL2DriverDialect(); + const unsupportedError = new UnsupportedMethodError("Keep alive configuration is not supported for MySQL2."); + + await expect(dialect.connect(mockHostInfo, props)).rejects.toThrow(unsupportedError); + + const keepAliveObj = { + keepAlive: true, + keepAliveInitialDelayMillis: 1234 + }; + + const propsWithObj = new Map([["wrapperKeepAliveProperties", keepAliveObj]]); + + await expect(dialect.connect(mockHostInfo, propsWithObj)).rejects.toThrow(unsupportedError); + }); + + it("test_connectWithKeepAliveProps_PG_shouldSucceed", async () => { + const keepAliveMap = new Map([ + ["keepAlive", true], + ["keepAliveInitialDelayMillis", 1234] + ]); + + const dialect = new NodePostgresDriverDialect(); + + dialect.setKeepAliveProperties(emptyProps, keepAliveMap); + expect(emptyProps.get("keepAlive")).toBe(true); + expect(emptyProps.get("keepAliveInitialDelayMillis")).toBe(1234); + + emptyProps.clear(); + + const keepAliveObj = { + keepAlive: true, + keepAliveInitialDelayMillis: 1234 + }; + + dialect.setKeepAliveProperties(emptyProps, keepAliveObj); + expect(emptyProps.get("keepAlive")).toBe(true); + expect(emptyProps.get("keepAliveInitialDelayMillis")).toBe(1234); + }); +}); diff --git a/tests/unit/failover_plugin.test.ts b/tests/unit/failover_plugin.test.ts index 4ab1d414..02b1ffaf 100644 --- a/tests/unit/failover_plugin.test.ts +++ b/tests/unit/failover_plugin.test.ts @@ -38,6 +38,8 @@ import { Messages } from "../../common/lib/utils/messages"; import { HostChangeOptions } from "../../common/lib/host_change_options"; import { NullTelemetryFactory } from "../../common/lib/utils/telemetry/null_telemetry_factory"; import { MySQLClientWrapper } from "../../common/lib/mysql_client_wrapper"; +import { DriverDialect } from "../../common/lib/driver_dialect/driver_dialect"; +import { MySQL2DriverDialect } from "../../mysql/lib/dialect/mysql2_driver_dialect"; const builder = new HostInfoBuilder({ hostAvailabilityStrategy: new SimpleHostAvailabilityStrategy() }); @@ -52,8 +54,9 @@ let mockWriterFailoverHandlerInstance; const mockWriterFailoverHandler: ClusterAwareWriterFailoverHandler = mock(ClusterAwareWriterFailoverHandler); const mockReaderResult: ReaderFailoverResult = mock(ReaderFailoverResult); const mockWriterResult: WriterFailoverResult = mock(WriterFailoverResult); +const mockDriverDialect: DriverDialect = mock(MySQL2DriverDialect); -const mockClientWrapper = new MySQLClientWrapper(undefined, mockHostInfo, new Map()); +const mockClientWrapper = new MySQLClientWrapper(undefined, mockHostInfo, new Map(), mockDriverDialect); const properties: Map = new Map(); diff --git a/tests/unit/writer_failover_handler.test.ts b/tests/unit/writer_failover_handler.test.ts index 9543dd90..5c484180 100644 --- a/tests/unit/writer_failover_handler.test.ts +++ b/tests/unit/writer_failover_handler.test.ts @@ -28,6 +28,7 @@ import { WriterFailoverResult } from "../../common/lib/plugins/failover/writer_f import { ClientWrapper } from "../../common/lib/client_wrapper"; import { PgDatabaseDialect } from "../../pg/lib/dialect/pg_database_dialect"; import { MySQLClientWrapper } from "../../common/lib/mysql_client_wrapper"; +import { DriverDialect } from "../../common/lib/driver_dialect/driver_dialect"; import { MySQL2DriverDialect } from "../../mysql/lib/dialect/mysql2_driver_dialect"; const builder = new HostInfoBuilder({ hostAvailabilityStrategy: new SimpleHostAvailabilityStrategy() }); @@ -45,6 +46,7 @@ const mockClient = mock(AwsPGClient); // Using AwsPGClient in order to have abst const mockClientInstance = instance(mockClient); const mockPluginService = mock(PluginService); const mockReaderFailover = mock(ClusterAwareReaderFailoverHandler); +const mockDriverDialect: DriverDialect = mock(MySQL2DriverDialect); const mockTargetClient = { client: 123 }; const mockClientWrapper: ClientWrapper = new MySQLClientWrapper(