Skip to content

Commit

Permalink
feat: session state transfer (#51)
Browse files Browse the repository at this point in the history
  • Loading branch information
crystall-bitquill authored May 24, 2024
1 parent 0e1f847 commit e22b5dd
Show file tree
Hide file tree
Showing 25 changed files with 1,617 additions and 56 deletions.
24 changes: 23 additions & 1 deletion common/lib/aws_client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ export abstract class AwsClient extends EventEmitter {
private _config: any;
protected isConnected: boolean = false;
protected _isReadOnly: boolean = false;
protected _isAutoCommit: boolean = true;
protected _catalog: string = "";
protected _schema: string = "";
protected _isolationLevel: number = 0;
private readonly _properties: Map<string, any>;
private _targetClient: any;
protected _errorHandler: ErrorHandler;
Expand Down Expand Up @@ -129,9 +133,27 @@ export abstract class AwsClient extends EventEmitter {

abstract isReadOnly(): boolean;

abstract setAutoCommit(autoCommit: boolean): Promise<any | void>;

abstract getAutoCommit(): boolean;

abstract setTransactionIsolation(transactionIsolation: number): Promise<any | void>;

abstract getTransactionIsolation(): number;

abstract setSchema(schema: any): Promise<any | void>;

abstract getSchema(): string;

abstract setCatalog(catalog: string): Promise<any | void>;

abstract getCatalog(): string;

abstract end(): Promise<any>;

abstract rollback(): void;
abstract rollback(): Promise<any>;

abstract resetState(): void;

async isValid(): Promise<boolean> {
if (!this.targetClient) {
Expand Down
5 changes: 5 additions & 0 deletions common/lib/database_dialect/database_dialect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,9 @@ export interface DatabaseDialect {
getConnectFunc(targetClient: any): () => Promise<any>;
getDatabaseType(): DatabaseType;
getDialectName(): string;
doesStatementSetReadOnly(statement: string): boolean | undefined;
doesStatementSetTransactionIsolation(statement: string): number | undefined;
doesStatementSetAutoCommit(statement: string): boolean | undefined;
doesStatementSetSchema(statement: string): string | undefined;
doesStatementSetCatalog(statement: string): string | undefined;
}
4 changes: 2 additions & 2 deletions common/lib/driver_connection_provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ export class DriverConnectionProvider implements ConnectionProvider {
const newTargetClient = pluginService.createTargetClient(props);
const fixedConnFunc = pluginService.getConnectFunc(newTargetClient);
result = await fixedConnFunc();
pluginService.setCurrentClient(newTargetClient, connectionHostInfo);
await pluginService.setCurrentClient(newTargetClient, connectionHostInfo);
}

return result;
Expand All @@ -106,7 +106,7 @@ export class DriverConnectionProvider implements ConnectionProvider {
getHostInfoByStrategy(hosts: HostInfo[], role: HostRole, strategy: string, props?: Map<string, any>): HostInfo {
const acceptedStrategy = DriverConnectionProvider.acceptedStrategies.get(strategy);
if (!acceptedStrategy) {
throw new AwsWrapperError(Messages.get("ConnectionProvider.unsupportedHostSpecSelectorStrategy", strategy, "DriverConnectionProvider")); // TODO
throw new AwsWrapperError(Messages.get("ConnectionProvider.unsupportedHostInfoSelectorStrategy", strategy, "DriverConnectionProvider"));
}
return acceptedStrategy.getHost(hosts, role, props);
}
Expand Down
4 changes: 2 additions & 2 deletions common/lib/plugin_manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ export class PluginManager {
return this.executeWithSubscribedPlugins(
hostInfo,
props,
"connect",
(plugin, nextPluginFunc) => plugin.connect(hostInfo, props, isInitialConnection, nextPluginFunc),
"forceConnect",
(plugin, nextPluginFunc) => plugin.forceConnect(hostInfo, props, isInitialConnection, nextPluginFunc),
methodFunc
);
}
Expand Down
92 changes: 74 additions & 18 deletions common/lib/plugin_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,23 @@ import { WrapperProperties } from "./wrapper_property";
import { OldConnectionSuggestionAction } from "./old_connection_suggestion_action";
import { DatabaseDialectProvider } from "./database_dialect/database_dialect_provider";
import { DatabaseDialectManager } from "./database_dialect/database_dialect_manager";
import { SqlMethodUtils } from "./utils/sql_method_utils";
import { SessionStateService } from "./session_state_service";
import { SessionStateServiceImpl } from "./session_state_service_impl";

export class PluginService implements ErrorHandler, HostListProviderService {
private readonly _currentClient: AwsClient;
private _currentHostInfo?: HostInfo;
private _hostListProvider?: HostListProvider;
private _initialConnectionHostInfo?: HostInfo;
private _isInTransaction: boolean = false;
private readonly _props: Map<string, any>;
private pluginServiceManagerContainer: PluginServiceManagerContainer;
private _props: Map<string, any>;
protected hosts: HostInfo[] = [];
private dbDialectProvider: DatabaseDialectProvider;
private initialHost: string;
private dialect: DatabaseDialect;
protected readonly sessionStateService: SessionStateService;
protected static readonly hostAvailabilityExpiringCache: CacheMap<string, HostAvailability> = new CacheMap<string, HostAvailability>();

constructor(
Expand All @@ -60,6 +64,7 @@ export class PluginService implements ErrorHandler, HostListProviderService {
this._props = props;
this.dbDialectProvider = new DatabaseDialectManager(knownDialectsByCode, dbType, this._props);
this.initialHost = props.get(WrapperProperties.HOST.name);
this.sessionStateService = new SessionStateServiceImpl(this, this._props);
container.pluginService = this;

this.dialect = this.dbDialectProvider.getDialect(this._props);
Expand Down Expand Up @@ -131,10 +136,10 @@ export class PluginService implements ErrorHandler, HostListProviderService {
}

async forceRefreshHostList(): Promise<void>;
async forceRefreshHostList(client?: AwsClient): Promise<void>;
async forceRefreshHostList(client?: AwsClient): Promise<void> {
const updatedHostList = client
? await this.getHostListProvider()?.forceRefresh(client.targetClient)
async forceRefreshHostList(targetClient?: any): Promise<void>;
async forceRefreshHostList(targetClient?: any): Promise<void> {
const updatedHostList = targetClient
? await this.getHostListProvider()?.forceRefresh(targetClient)
: await this.getHostListProvider()?.forceRefresh();
if (updatedHostList && updatedHostList !== this.hosts) {
this.updateHostAvailability(updatedHostList);
Expand All @@ -143,9 +148,9 @@ export class PluginService implements ErrorHandler, HostListProviderService {
}

async refreshHostList(): Promise<void>;
async refreshHostList(client: AwsClient): Promise<void>;
async refreshHostList(client?: AwsClient): Promise<void> {
const updatedHostList = client ? await this.getHostListProvider()?.refresh(client.targetClient) : await this.getHostListProvider()?.refresh();
async refreshHostList(targetClient: any): Promise<void>;
async refreshHostList(targetClient?: any): Promise<void> {
const updatedHostList = targetClient ? await this.getHostListProvider()?.refresh(targetClient) : await this.getHostListProvider()?.refresh();
if (updatedHostList && updatedHostList !== this.hosts) {
this.updateHostAvailability(updatedHostList);
this.setHostList(this.hosts, updatedHostList);
Expand Down Expand Up @@ -257,25 +262,32 @@ export class PluginService implements ErrorHandler, HostListProviderService {
throw new AwsWrapperError("AwsClient is missing target client connect function."); // This should not be reached
}

// TODO: Add session state changes
async setCurrentClient(newClient: any, hostInfo: HostInfo): Promise<Set<HostChangeOptions>> {
if (this.getCurrentClient().targetClient === null) {
this.getCurrentClient().targetClient = newClient;
this._currentHostInfo = hostInfo;
this.sessionStateService.reset();
const changes = new Set<HostChangeOptions>([HostChangeOptions.INITIAL_CONNECTION]);

if (this.pluginServiceManagerContainer.pluginManager) {
await this.pluginServiceManagerContainer.pluginManager.notifyConnectionChanged(changes, null);
}

return changes;
} else {
if (this._currentHostInfo) {
const changes: Set<HostChangeOptions> = this.compare(this._currentHostInfo, hostInfo);

if (changes.size > 0) {
const oldClient: any = this.getCurrentClient().targetClient;
const isInTransaction = this.isInTransaction;
this.sessionStateService.begin();

try {
this.getCurrentClient().resetState();
this.getCurrentClient().targetClient = newClient;
this._currentHostInfo = hostInfo;
await this.sessionStateService.applyCurrentSessionState(this.getCurrentClient());
this.setInTransaction(false);

if (this.pluginServiceManagerContainer.pluginManager) {
Expand All @@ -288,7 +300,7 @@ export class PluginService implements ErrorHandler, HostListProviderService {
(await oldClient.isValid());
}
} finally {
/* empty */
this.sessionStateService.complete();
}
}
return changes;
Expand All @@ -311,16 +323,25 @@ export class PluginService implements ErrorHandler, HostListProviderService {
return this.getDialect().getConnectFunc(targetClient);
}

getSessionStateService() {
return this.sessionStateService;
}

async updateState(sql: string) {
this.updateInTransaction(sql);

const statements = SqlMethodUtils.parseMultiStatementQueries(sql);
await this.updateReadOnly(statements);
await this.updateAutoCommit(statements);
await this.updateCatalog(statements);
await this.updateSchema(statements);
await this.updateTransactionIsolation(statements);
}

updateInTransaction(sql: string) {
// TODO: revise with session state transfer
if (sql.toLowerCase().startsWith("start transaction") || sql.toLowerCase().startsWith("begin")) {
if (SqlMethodUtils.doesOpenTransaction(sql)) {
this.setInTransaction(true);
} else if (
sql.toLowerCase().startsWith("commit") ||
sql.toLowerCase().startsWith("rollback") ||
sql.toLowerCase().startsWith("end") ||
sql.toLowerCase().startsWith("abort")
) {
} else if (SqlMethodUtils.doesCloseTransaction(sql)) {
this.setInTransaction(false);
}
}
Expand All @@ -335,4 +356,39 @@ export class PluginService implements ErrorHandler, HostListProviderService {

this._hostListProvider = this.dialect.getHostListProvider(this._props, this._props.get(WrapperProperties.HOST.name), this);
}

private async updateReadOnly(statements: string[]) {
const updateReadOnly = SqlMethodUtils.doesSetReadOnly(statements, this.getDialect());
if (updateReadOnly !== undefined) {
await this.getCurrentClient().setReadOnly(updateReadOnly);
}
}

private async updateAutoCommit(statements: string[]) {
const updateAutoCommit = SqlMethodUtils.doesSetAutoCommit(statements, this.getDialect());
if (updateAutoCommit !== undefined) {
await this.getCurrentClient().setAutoCommit(updateAutoCommit);
}
}

private async updateCatalog(statements: string[]) {
const updateCatalog = SqlMethodUtils.doesSetCatalog(statements, this.getDialect());
if (updateCatalog !== undefined) {
await this.getCurrentClient().setCatalog(updateCatalog);
}
}

private async updateSchema(statements: string[]) {
const updateSchema = SqlMethodUtils.doesSetSchema(statements, this.getDialect());
if (updateSchema !== undefined) {
await this.getCurrentClient().setSchema(updateSchema);
}
}

private async updateTransactionIsolation(statements: string[]) {
const updateTransactionIsolation = SqlMethodUtils.doesSetTransactionIsolation(statements, this.getDialect());
if (updateTransactionIsolation !== undefined) {
await this.getCurrentClient().setTransactionIsolation(updateTransactionIsolation);
}
}
}
3 changes: 1 addition & 2 deletions common/lib/plugins/failover/failover_plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -420,8 +420,7 @@ export class FailoverPlugin extends AbstractConnectionPlugin {
if (this.pluginService.isInTransaction()) {
this._isInTransaction = this.pluginService.isInTransaction();
try {
// TODO: rollback not implemented
client.rollback();
await client.rollback();
} catch (error) {
// swallow this error
}
Expand Down
10 changes: 5 additions & 5 deletions common/lib/plugins/failover/writer_failover_handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -113,20 +113,20 @@ export class ClusterAwareWriterFailoverHandler implements WriterFailoverHandler
}, this.maxFailoverTimeoutMs);
});

const task1 = reconnectToWriterHandlerTask.call();
const task2 = waitForNewWriterHandlerTask.call();
const taskA = reconnectToWriterHandlerTask.call();
const taskB = waitForNewWriterHandlerTask.call();

const failoverTask = Promise.any([task1, task2])
const failoverTask = Promise.any([taskA, taskB])
.then((result) => {
this.test = true;
if (result.isConnected || result.exception) {
return result;
}

if (reconnectToWriterHandlerTask.taskComplete) {
return task2;
return taskB;
} else if (waitForNewWriterHandlerTask.taskComplete) {
return task1;
return taskA;
}
return ClusterAwareWriterFailoverHandler.DEFAULT_RESULT;
})
Expand Down
Loading

0 comments on commit e22b5dd

Please sign in to comment.