Skip to content

Commit

Permalink
feat: impl PoolWaitTimeoutError (#120)
Browse files Browse the repository at this point in the history
Throw error when get connection wait time great than poolWaitTimeout.
  • Loading branch information
killagu authored Aug 12, 2024
1 parent 855094e commit 7a355d3
Show file tree
Hide file tree
Showing 7 changed files with 178 additions and 10 deletions.
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ const db = new RDSClient({
// connectionStorage: new AsyncLocalStorage(),
// If create multiple RDSClient instances with the same connectionStorage, use this key to distinguish between the instances
// connectionStorageKey: 'datasource',

// The timeout for connecting to the MySQL server. (Default: 500 milliseconds)
// connectTimeout: 500,

// The timeout for waiting for a connection from the connection pool. (Default: 500 milliseconds)
// So max timeout for get a connection is (connectTimeout + poolWaitTimeout)
// poolWaitTimeout: 500,
});
```

Expand Down
57 changes: 51 additions & 6 deletions src/client.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { AsyncLocalStorage } from 'node:async_hooks';
import { promisify } from 'node:util';
import { setTimeout } from 'node:timers/promises';
import mysql from 'mysql';
import type { Pool } from 'mysql';
import type { PoolConnectionPromisify, RDSClientOptions, TransactionContext, TransactionScope } from './types';
Expand All @@ -10,33 +11,55 @@ import { RDSPoolConfig } from './PoolConfig';
import literals from './literals';
import channels from './channels';
import type { ConnectionMessage, ConnectionEnqueueMessage } from './channels';
import { PoolWaitTimeoutError } from './util/PoolWaitTimeout';

export * from './types';

interface PoolPromisify extends Omit<Pool, 'query'> {
query(sql: string): Promise<any>;

getConnection(): Promise<PoolConnectionPromisify>;

end(): Promise<void>;

_acquiringConnections: any[];
_allConnections: any[];
_freeConnections: any[];
_connectionQueue: any[];
}

export class RDSClient extends Operator {
static get literals() { return literals; }
static get escape() { return mysql.escape; }
static get escapeId() { return mysql.escapeId; }
static get format() { return mysql.format; }
static get raw() { return mysql.raw; }
static get literals() {
return literals;
}

static get escape() {
return mysql.escape;
}

static get escapeId() {
return mysql.escapeId;
}

static get format() {
return mysql.format;
}

static get raw() {
return mysql.raw;
}

static #DEFAULT_STORAGE_KEY = Symbol('RDSClient#storage#default');
static #TRANSACTION_NEST_COUNT = Symbol('RDSClient#transaction#nestCount');

#pool: PoolPromisify;
#connectionStorage: AsyncLocalStorage<TransactionContext>;
#connectionStorageKey: string | symbol;
#poolWaitTimeout: number;

constructor(options: RDSClientOptions) {
super();
options.connectTimeout = options.connectTimeout ?? 500;
const { connectionStorage, connectionStorageKey, ...mysqlOptions } = options;
// get connection options from getConnectionConfig method every time
if (mysqlOptions.getConnectionConfig) {
Expand All @@ -59,6 +82,7 @@ export class RDSClient extends Operator {
});
this.#connectionStorage = connectionStorage || new AsyncLocalStorage();
this.#connectionStorageKey = connectionStorageKey || RDSClient.#DEFAULT_STORAGE_KEY;
this.#poolWaitTimeout = options.poolWaitTimeout ?? 500;
// https://github.com/mysqljs/mysql#pool-events
this.#pool.on('connection', (connection: PoolConnectionPromisify) => {
channels.connectionNew.publish({
Expand Down Expand Up @@ -113,9 +137,30 @@ export class RDSClient extends Operator {
};
}

async waitPoolConnection(abortSignal: AbortSignal) {
const now = performance.now();
await setTimeout(this.#poolWaitTimeout, undefined, { signal: abortSignal });
return performance.now() - now;
}

async getConnectionWithTimeout() {
const connPromise = this.#pool.getConnection();
const timeoutAbortController = new AbortController();
const timeoutPromise = this.waitPoolConnection(timeoutAbortController.signal);
const connOrTimeout = await Promise.race([ connPromise, timeoutPromise ]);
if (typeof connOrTimeout === 'number') {
connPromise.then(conn => {
conn.release();
});
throw new PoolWaitTimeoutError(`get connection timeout after ${connOrTimeout}ms`);
}
timeoutAbortController.abort();
return connPromise;
}

async getConnection() {
try {
const _conn = await this.#pool.getConnection();
const _conn = await this.getConnectionWithTimeout();
const conn = new RDSConnection(_conn);
if (this.beforeQueryHandlers.length > 0) {
for (const handler of this.beforeQueryHandlers) {
Expand Down
6 changes: 6 additions & 0 deletions src/connection.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import assert from 'node:assert';
import { promisify } from 'node:util';
import { Operator } from './operator';
import type { PoolConnectionPromisify } from './types';
Expand All @@ -6,8 +7,11 @@ const kWrapToRDS = Symbol('kWrapToRDS');

export class RDSConnection extends Operator {
conn: PoolConnectionPromisify;
#released: boolean;

constructor(conn: PoolConnectionPromisify) {
super(conn);
this.#released = false;
this.conn = conn;
if (!this.conn[kWrapToRDS]) {
[
Expand All @@ -23,6 +27,8 @@ export class RDSConnection extends Operator {
}

release() {
assert(!this.#released, 'connection was released');
this.#released = true;
return this.conn.release();
}

Expand Down
4 changes: 4 additions & 0 deletions src/transaction.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import type { RDSConnection } from './connection';
import { Operator } from './operator';

let id = 0;
export class RDSTransaction extends Operator {
isCommit = false;
isRollback = false;
conn: RDSConnection | null;
id: number;

constructor(conn: RDSConnection) {
super(conn.conn);
this.id = id++;
this.conn = conn;
}

Expand Down
3 changes: 2 additions & 1 deletion src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ export type GetConnectionConfig = () => ConnectionConfig;
export interface RDSClientOptions extends PoolConfig {
connectionStorageKey?: string;
connectionStorage?: AsyncLocalStorage<Record<PropertyKey, RDSTransaction>>;
getConnectionConfig?: GetConnectionConfig
getConnectionConfig?: GetConnectionConfig;
poolWaitTimeout?: number;
}

export interface PoolConnectionPromisify extends Omit<PoolConnection, 'query'> {
Expand Down
6 changes: 6 additions & 0 deletions src/util/PoolWaitTimeout.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
export class PoolWaitTimeoutError extends Error {
constructor(...args) {
super(...args);
this.name = 'PoolWaitTimeoutError';
}
}
105 changes: 102 additions & 3 deletions test/client.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { AsyncLocalStorage } from 'node:async_hooks';
import { strict as assert } from 'node:assert';
import fs from 'node:fs/promises';
import { setTimeout } from 'node:timers/promises';
import path from 'node:path';
import mm from 'mm';
import { RDSTransaction } from '../src/transaction';
Expand Down Expand Up @@ -298,8 +299,9 @@ describe('test/client.test.ts', () => {
// recovered after unlock.
await conn.query('select * from `ali-sdk-test-user` limit 1;');
} catch (err) {
conn.release();
throw err;
} finally {
conn.release();
}
});

Expand Down Expand Up @@ -353,7 +355,8 @@ describe('test/client.test.ts', () => {
});

it('should throw rollback error with cause error when rollback failed', async () => {
mm(RDSTransaction.prototype, 'rollback', async () => {
mm(RDSTransaction.prototype, 'rollback', async function(this: RDSTransaction) {
this.conn!.release();
throw new Error('fake rollback error');
});
await assert.rejects(
Expand Down Expand Up @@ -501,7 +504,9 @@ describe('test/client.test.ts', () => {
});
};

const [ p1Res, p2Res ] = await Promise.all([ p1(), p2().catch(err => err) ]);
const [ p1Res, p2Res ] = await Promise.all([ p1(), p2().catch(err => {
return err;
}) ]);
assert.strictEqual(p1Res, true);
assert.strictEqual(p2Res.code, 'ER_PARSE_ERROR');
const rows = await db.query('select * from ?? where email=? order by id',
Expand Down Expand Up @@ -1458,4 +1463,98 @@ describe('test/client.test.ts', () => {
assert.equal(counter2After, 4);
});
});

describe('PoolWaitTimeout', () => {
async function longQuery(timeout?: number) {
await db.beginTransactionScope(async conn => {
await setTimeout(timeout ?? 1000);
await conn.query('SELECT 1+1');
});
}

it('should throw error if pool wait timeout', async () => {
const tasks: Array<Promise<void>> = [];
for (let i = 0; i < 10; i++) {
tasks.push(longQuery());
}
const tasksPromise = Promise.all(tasks);
await assert.rejects(async () => {
await longQuery();
}, /get connection timeout after/);
await tasksPromise;
});

it('should release conn to pool', async () => {
const tasks: Array<Promise<void>> = [];
const timeoutTasks: Array<Promise<void>> = [];
// 1. fill the pool
for (let i = 0; i < 10; i++) {
tasks.push(longQuery());
}
// 2. add more conn and wait for timeout
for (let i = 0; i < 10; i++) {
timeoutTasks.push(longQuery());
}
const [ succeedTasks, failedTasks ] = await Promise.all([
Promise.allSettled(tasks),
Promise.allSettled(timeoutTasks),
]);
const succeedCount = succeedTasks.filter(t => t.status === 'fulfilled').length;
assert.equal(succeedCount, 10);

const failedCount = failedTasks.filter(t => t.status === 'rejected').length;
assert.equal(failedCount, 10);

// 3. after pool empty, create new tasks
const retryTasks: Array<Promise<void>> = [];
for (let i = 0; i < 10; i++) {
retryTasks.push(longQuery());
}
await Promise.all(retryTasks);
});

it('should not wait too long', async () => {
const tasks: Array<Promise<void>> = [];
const timeoutTasks: Array<Promise<void>> = [];
const fastTasks: Array<Promise<void>> = [];
const start = performance.now();
// 1. fill the pool
for (let i = 0; i < 10; i++) {
tasks.push(longQuery());
}
const tasksPromise = Promise.allSettled(tasks);
// 2. add more conn and wait for timeout
for (let i = 0; i < 10; i++) {
timeoutTasks.push(longQuery());
}
const timeoutTasksPromise = Promise.allSettled(timeoutTasks);
await setTimeout(600);
// 3. add fast query
for (let i = 0; i < 10; i++) {
fastTasks.push(longQuery(1));
}
const fastTasksPromise = Promise.allSettled(fastTasks);
const [ succeedTasks, failedTasks, fastTaskResults ] = await Promise.all([
tasksPromise,
timeoutTasksPromise,
fastTasksPromise,
]);
const duration = performance.now() - start;
const succeedCount = succeedTasks.filter(t => t.status === 'fulfilled').length;
assert.equal(succeedCount, 10);

const failedCount = failedTasks.filter(t => t.status === 'rejected').length;
assert.equal(failedCount, 10);

const faskTaskSucceedCount = fastTaskResults.filter(t => t.status === 'fulfilled').length;
assert.equal(faskTaskSucceedCount, 10);

// - 10 long queries cost 1000ms
// - 10 timeout queries should be timeout in long query execution so not cost time
// - 10 fast queries wait long query to finish, cost 1ms
// 1000ms + 0ms + 1ms < 1100ms
assert(duration < 1100);
});

});
});

0 comments on commit 7a355d3

Please sign in to comment.