From d1b45628a4a8da9af78e5830d628164c26a0a78f Mon Sep 17 00:00:00 2001 From: Neal Beeken Date: Tue, 14 Jan 2025 14:21:35 -0500 Subject: [PATCH] wip: testing --- src/client-side-encryption/auto_encrypter.ts | 12 +- src/client-side-encryption/state_machine.ts | 1 + src/cmap/connection.ts | 5 +- src/cmap/wire_protocol/on_data.ts | 2 + src/utils.ts | 2 + .../node-specific/abort_signal.test.ts | 365 ++++++++++++++++-- 6 files changed, 343 insertions(+), 44 deletions(-) diff --git a/src/client-side-encryption/auto_encrypter.ts b/src/client-side-encryption/auto_encrypter.ts index 7111df25e5..1d7a9de4c6 100644 --- a/src/client-side-encryption/auto_encrypter.ts +++ b/src/client-side-encryption/auto_encrypter.ts @@ -11,6 +11,7 @@ import { kDecorateResult } from '../constants'; import { getMongoDBClientEncryption } from '../deps'; import { MongoRuntimeError } from '../error'; import { MongoClient, type MongoClientOptions } from '../mongo_client'; +import { type Abortable } from '../mongo_types'; import { MongoDBCollectionNamespace } from '../utils'; import { autoSelectSocketOptions } from './client_encryption'; import * as cryptoCallbacks from './crypto_callbacks'; @@ -372,8 +373,10 @@ export class AutoEncrypter { async encrypt( ns: string, cmd: Document, - options: CommandOptions = {} + options: CommandOptions & Abortable = {} ): Promise { + options.signal?.throwIfAborted(); + if (this._bypassEncryption) { // If `bypassAutoEncryption` has been specified, don't encrypt return cmd; @@ -407,7 +410,12 @@ export class AutoEncrypter { /** * Decrypt a command response */ - async decrypt(response: Uint8Array, options: CommandOptions = {}): Promise { + async decrypt( + response: Uint8Array, + options: CommandOptions & Abortable = {} + ): Promise { + options.signal?.throwIfAborted(); + const context = this._mongocrypt.makeDecryptionContext(response); context.id = this._contextCounter++; diff --git a/src/client-side-encryption/state_machine.ts b/src/client-side-encryption/state_machine.ts index d0236fd8fe..07dad3c578 100644 --- a/src/client-side-encryption/state_machine.ts +++ b/src/client-side-encryption/state_machine.ts @@ -206,6 +206,7 @@ export class StateMachine { let result: Uint8Array | null = null; while (context.state !== MONGOCRYPT_CTX_DONE && context.state !== MONGOCRYPT_CTX_ERROR) { + options.signal?.throwIfAborted(); debug(`[context#${context.id}] ${stateToString.get(context.state) || context.state}`); switch (context.state) { diff --git a/src/cmap/connection.ts b/src/cmap/connection.ts index 6e3f04e59a..40644bf1be 100644 --- a/src/cmap/connection.ts +++ b/src/cmap/connection.ts @@ -474,7 +474,10 @@ export class Connection extends TypedEventEmitter { ); } - for await (const response of this.readMany({ timeoutContext: options.timeoutContext })) { + for await (const response of this.readMany({ + timeoutContext: options.timeoutContext, + signal: options.signal + })) { this.socket.setTimeout(0); const bson = response.parse(); diff --git a/src/cmap/wire_protocol/on_data.ts b/src/cmap/wire_protocol/on_data.ts index 11ccf6a5d1..82dd7b40db 100644 --- a/src/cmap/wire_protocol/on_data.ts +++ b/src/cmap/wire_protocol/on_data.ts @@ -24,6 +24,8 @@ export function onData( emitter: EventEmitter, { timeoutContext, signal }: { timeoutContext?: TimeoutContext } & Abortable ) { + signal?.throwIfAborted(); + // Setup pending events and pending promise lists /** * When the caller has not yet called .next(), we store the diff --git a/src/utils.ts b/src/utils.ts index 8498595b9f..69a25e230f 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -1351,6 +1351,8 @@ export const randomBytes = promisify(crypto.randomBytes); * @param name - An event name to wait for */ export async function once(ee: EventEmitter, name: string, options?: Abortable): Promise { + options?.signal?.throwIfAborted(); + const { promise, resolve, reject } = promiseWithResolvers(); const onEvent = (data: T) => resolve(data); const onError = (error: Error) => reject(error); diff --git a/test/integration/node-specific/abort_signal.test.ts b/test/integration/node-specific/abort_signal.test.ts index 881700a2e9..812fcfdfda 100644 --- a/test/integration/node-specific/abort_signal.test.ts +++ b/test/integration/node-specific/abort_signal.test.ts @@ -1,10 +1,15 @@ +import * as child_process from 'node:child_process'; +import * as events from 'node:events'; import * as util from 'node:util'; import { expect } from 'chai'; +import * as sinon from 'sinon'; import { type AbstractCursor, AggregationCursor, + AutoEncrypter, + ClientEncryption, type Collection, type Db, FindCursor, @@ -12,9 +17,16 @@ import { type Log, type MongoClient, ReadPreference, - setDifference + setDifference, + UUID } from '../../mongodb'; -import { DOMException, findLast, sleep } from '../../tools/utils'; +import { + clearFailPoint, + configureFailPoint, + DOMException, + findLast, + sleep +} from '../../tools/utils'; const isAsyncGenerator = (value: any): value is AsyncGenerator => value[Symbol.toStringTag] === 'AsyncGenerator'; @@ -33,34 +45,22 @@ function getAllProps(value) { describe('AbortSignal support', () => { let client: MongoClient; let db: Db; - let collection: Collection<{ a: number }>; + let collection: Collection<{ a: number; ssn: string }>; const logs: Log[] = []; beforeEach(async function () { - const utilClient = this.configuration.newClient(); - try { - await utilClient.db('abortSignal').collection('support').deleteMany({}); - await utilClient - .db('abortSignal') - .collection('support') - .insertMany([{ a: 1 }, { a: 2 }, { a: 3 }]); - } finally { - await utilClient.close(); - } - logs.length = 0; client = this.configuration.newClient( {}, { + monitorCommands: true, + appName: 'abortSignalClient', __enableMongoLogger: true, - __internalLoggerConfig: { - MONGODB_LOG_SERVER_SELECTION: 'debug' - }, - mongodbLogPath: { - write: log => logs.push(log) - }, - serverSelectionTimeoutMS: 5000 + __internalLoggerConfig: { MONGODB_LOG_SERVER_SELECTION: 'debug' }, + mongodbLogPath: { write: log => logs.push(log) }, + serverSelectionTimeoutMS: 10_000, + maxPoolSize: 1 } ); await client.connect(); @@ -76,11 +76,20 @@ describe('AbortSignal support', () => { } finally { await utilClient.close(); } - await client.close(); + await client?.close(); }); - function testCursor(name: string, constructor: any) { - describe(`when ${name}() is given a signal`, () => { + function testCursor(cursorName: string, constructor: any) { + let method; + let filter; + beforeEach(function () { + method = (cursorName === 'listCollections' ? db[cursorName] : collection[cursorName]).bind( + cursorName === 'listCollections' ? db : collection + ); + filter = cursorName === 'aggregate' ? [] : {}; + }); + + describe(`when ${cursorName}() is given a signal`, () => { const cursorAPIs = { tryNext: [], hasNext: [], @@ -149,7 +158,7 @@ describe('AbortSignal support', () => { signal = controller.signal; controller.abort(); - cursor = collection.find({}, { signal }); + cursor = method(cursorName === 'aggregate' ? [] : {}, { signal }); }); afterEach(async () => { @@ -172,7 +181,7 @@ describe('AbortSignal support', () => { beforeEach(() => { controller = new AbortController(); signal = controller.signal; - cursor = collection.find({}, { signal }); + cursor = method(filter, { signal }); }); afterEach(async () => { @@ -204,21 +213,18 @@ describe('AbortSignal support', () => { function test(cursorAPI, args) { let controller: AbortController; let signal: AbortSignal; - let cursor: FindCursor<{ a: number }>; + let cursor: AbstractCursor<{ a: number }>; beforeEach(() => { controller = new AbortController(); signal = controller.signal; - cursor = collection.find( - {}, - { - signal, - // Pick an unselectable server - readPreference: new ReadPreference('secondary', [ - { something: 'that does not exist' } - ]) - } - ); + cursor = method(filter, { + signal, + // Pick an unselectable server + readPreference: new ReadPreference('secondary', [ + { something: 'that does not exist' } + ]) + }); }); afterEach(async () => { @@ -229,10 +235,14 @@ describe('AbortSignal support', () => { const willBeResult = captureCursorError(cursor, cursorAPI, args); await sleep(3); - expect(findLast(logs, l => l.operation === 'find')).to.have.property( - 'message', - 'Waiting for suitable server to become available' - ); + expect( + findLast( + logs, + l => + l.operation === cursorName && + l.message === 'Waiting for suitable server to become available' + ) + ).to.exist; controller.abort(); const start = performance.now(); @@ -248,6 +258,279 @@ describe('AbortSignal support', () => { test(cursorAPI, args); } }); + + describe('and the signal is aborted during connection checkout', () => { + function test(cursorAPI, args) { + let controller: AbortController; + let signal: AbortSignal; + let cursor: AbstractCursor<{ a: number }>; + + beforeEach(async function () { + await configureFailPoint(this.configuration, { + configureFailPoint: 'failCommand', + mode: { times: 1 }, + data: { + appName: 'abortSignalClient', + failCommands: [cursorName], + blockConnection: true, + blockTimeMS: 300 + } + }); + + controller = new AbortController(); + signal = controller.signal; + cursor = method(filter, { signal }); + }); + + afterEach(async function () { + await clearFailPoint(this.configuration); + await cursor?.close(); + }); + + it(`rejects ${cursorAPI.toString()}`, async () => { + const checkoutSucceededFirst = events.once(client, 'connectionCheckedOut'); + const checkoutStartedBlocked = events.once(client, 'connectionCheckOutStarted'); + + const _ = captureCursorError(cursor, cursorAPI, args); + const willBeResultBlocked = captureCursorError(cursor, cursorAPI, args); + + await checkoutSucceededFirst; + await checkoutStartedBlocked; + + controller.abort(); + const result = await willBeResultBlocked; + + expect(result).to.be.instanceOf(DOMException); + }); + } + + for (const [cursorAPI, { value: args }] of getAllProps(cursorAPIs)) { + test(cursorAPI, args); + } + }); + + describe('and the signal is aborted during connection write', () => { + function test(cursorAPI, args) { + let controller: AbortController; + let signal: AbortSignal; + let cursor: AbstractCursor<{ a: number }>; + + beforeEach(async function () { + controller = new AbortController(); + signal = controller.signal; + cursor = method(filter, { signal }); + }); + + afterEach(async function () { + sinon.restore(); + await cursor?.close(); + }); + + it(`rejects ${cursorAPI.toString()}`, async () => { + await db.command({ ping: 1 }, { readPreference: 'primary' }); // fill the connection pool with 1 connection. + + // client.once('commandStarted', () => controller.abort()); + const willBeResultBlocked = captureCursorError(cursor, cursorAPI, args); + + for (const [, server] of client.topology.s.servers) { + //@ts-expect-error: private property + for (const connection of server.pool.connections) { + //@ts-expect-error: private property + const stub = sinon.stub(connection.socket, 'write').callsFake(function (...args) { + controller.abort(); + sleep(1).then(() => { + stub.wrappedMethod.apply(this, args); + this.emit('drain'); + }); + return false; + }); + } + } + + const result = await willBeResultBlocked; + + expect(result).to.be.instanceOf(DOMException); + }); + } + + for (const [cursorAPI, { value: args }] of getAllProps(cursorAPIs)) { + test(cursorAPI, args); + } + }); + + describe('and the signal is aborted during connection read', () => { + function test(cursorAPI, args) { + let controller: AbortController; + let signal: AbortSignal; + let cursor: AbstractCursor<{ a: number }>; + + beforeEach(async function () { + await configureFailPoint(this.configuration, { + configureFailPoint: 'failCommand', + mode: { times: 1 }, + data: { + appName: 'abortSignalClient', + failCommands: [cursorName], + blockConnection: true, + blockTimeMS: 300 + } + }); + + controller = new AbortController(); + signal = controller.signal; + cursor = method(filter, { signal }); + }); + + afterEach(async function () { + await clearFailPoint(this.configuration); + await cursor?.close(); + }); + + it(`rejects ${cursorAPI.toString()}`, async () => { + await db.command({ ping: 1 }, { readPreference: 'primary' }); // fill the connection pool with 1 connection. + + client.on('commandStarted', e => e.commandName === cursorName && controller.abort()); + const willBeResultBlocked = captureCursorError(cursor, cursorAPI, args); + + const result = await willBeResultBlocked; + + expect(result).to.be.instanceOf(DOMException); + }); + } + + for (const [cursorAPI, { value: args }] of getAllProps(cursorAPIs)) { + test(cursorAPI, args); + } + }); + + const fleMetadata: MongoDBMetadataUI = { requires: { clientSideEncryption: true } }; + + describe('setup fle', fleMetadata, () => { + let autoEncryption; + + before(async function () { + autoEncryption = { + keyVaultNamespace: 'admin.datakeys', + kmsProviders: { + kmip: { endpoint: 'localhost:5698' } + }, + tlsOptions: { + kmip: { + tlsCAFile: process.env.KMIP_TLS_CA_FILE, + tlsCertificateKeyFile: process.env.KMIP_TLS_CERT_FILE + } + }, + encryptedFieldsMap: { + 'abortSignal.support': { + fields: [ + { + path: 'ssn', + keyId: null, + bsonType: 'string' + } + ] + } + } + }; + + let utilClient = this.configuration.newClient({}, {}); + + try { + await utilClient.db('abortSignal').collection('support').drop({}); + + const clientEncryption = new ClientEncryption(utilClient, { + ...autoEncryption, + encryptedFieldsMap: undefined + }); + + autoEncryption.encryptedFieldsMap['abortSignal.support'] = ( + await clientEncryption.createEncryptedCollection( + utilClient.db('abortSignal'), + 'support', + { + provider: 'kmip', + createCollectionOptions: { + encryptedFields: autoEncryption.encryptedFieldsMap['abortSignal.support'] + } + } + ) + ).encryptedFields; + } finally { + await utilClient.close(); + } + + utilClient = this.configuration.newClient({}, { autoEncryption }); + try { + await utilClient + .db('abortSignal') + .collection('support') + .insertMany([ + { a: 1, ssn: '0000-00-0001' }, + { a: 2, ssn: '0000-00-0002' }, + { a: 3, ssn: '0000-00-0003' } + ]); + } finally { + await utilClient.close(); + } + }); + + beforeEach(async function () { + await client?.close(); + client = undefined; + + logs.length = 0; + + client = this.configuration.newClient( + {}, + { + autoEncryption, + monitorCommands: true, + appName: 'abortSignalClient', + __enableMongoLogger: true, + __internalLoggerConfig: { MONGODB_LOG_SERVER_SELECTION: 'debug' }, + mongodbLogPath: { write: log => logs.push(log) }, + serverSelectionTimeoutMS: 10_000, + maxPoolSize: 1 + } + ); + await client.connect(); + db = client.db('abortSignal'); + collection = db.collection('support'); + }); + + describe('and the signal is aborted during command encryption', fleMetadata, () => { + function test(cursorAPI, args) { + let controller: AbortController; + let signal: AbortSignal; + let cursor: AbstractCursor<{ a: number }>; + + beforeEach(async function () { + controller = new AbortController(); + signal = controller.signal; + cursor = method(filter, { signal }); + }); + + afterEach(async function () { + sinon.restore(); + await cursor?.close(); + }); + + it(`rejects ${cursorAPI.toString()}`, fleMetadata, async () => { + await client.connect(); + + const willBeResultBlocked = captureCursorError(cursor, cursorAPI, args); + + const result = await willBeResultBlocked; + + expect(result).to.be.instanceOf(DOMException); + }); + } + + for (const [cursorAPI, { value: args }] of getAllProps(cursorAPIs)) { + test(cursorAPI, args); + } + }); + }); }); }