Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(NODE-6258): add signal support to find and aggregate #4364

Merged
merged 33 commits into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
ab02d53
feat(NODE-6258): add signal support to cursor APIs
nbbeeken Jan 15, 2025
0d1c165
chore: readmany options
nbbeeken Jan 17, 2025
3de15cf
chore: drain options
nbbeeken Jan 17, 2025
96c3612
chore: explicit signal
nbbeeken Jan 17, 2025
ee00a38
docs: fix up api docs
nbbeeken Jan 17, 2025
8955d8e
test: better helper name
nbbeeken Jan 17, 2025
8e6ec04
test: update name
nbbeeken Jan 17, 2025
a73940e
test: improve iteration test organization
nbbeeken Jan 17, 2025
45a4b65
test: cruft
nbbeeken Jan 17, 2025
e9338cb
feat: make sure connections are closed after abort if aborted during …
nbbeeken Jan 17, 2025
152be95
test: remove redundant fle tests
nbbeeken Jan 17, 2025
5af48af
chore: make findLast simple
nbbeeken Jan 17, 2025
6e8bd46
test: no kill cursors on lb and don't wait on connection close
nbbeeken Jan 17, 2025
263b185
chore: lint
nbbeeken Jan 21, 2025
017c77e
chore: consistent cursor option types
nbbeeken Jan 22, 2025
7683fc9
docs: connection churn
nbbeeken Jan 22, 2025
ef231ee
fix: do not hang when the signal is aborted
nbbeeken Jan 22, 2025
ec2feee
fix: state checks
nbbeeken Jan 23, 2025
93b88fb
docs
nbbeeken Jan 23, 2025
89beb22
test: getaddrinfo error codes aren't on windows
nbbeeken Jan 23, 2025
c2ef0a6
experimental
nbbeeken Jan 23, 2025
f1e4f86
test filter
nbbeeken Jan 23, 2025
aabd070
check time pass
nbbeeken Jan 23, 2025
a46a919
where
nbbeeken Jan 24, 2025
b69286b
test: add where test
nbbeeken Jan 24, 2025
7f14566
test: stream test flaky
nbbeeken Jan 24, 2025
c24cf28
fix: reauth must finish before check in
nbbeeken Jan 24, 2025
dc52b67
test: flake
nbbeeken Jan 24, 2025
76fee52
test: lb
nbbeeken Jan 24, 2025
f05e8bc
test: flake
nbbeeken Jan 24, 2025
f30dc05
fix options name, docs for abortable, docs for signal
nbbeeken Jan 24, 2025
907a1a8
Merge branch 'main' into NODE-6258-abortsignal
W-A-James Jan 24, 2025
0e6f566
docs
nbbeeken Jan 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions src/client-side-encryption/auto_encrypter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import { kDecorateResult } from '../constants';
import { getMongoDBClientEncryption } from '../deps';
import { MongoRuntimeError } from '../error';
import { MongoClient, type MongoClientOptions } from '../mongo_client';
import { type Abortable } from '../mongo_types';
import { MongoDBCollectionNamespace } from '../utils';
import { autoSelectSocketOptions } from './client_encryption';
import * as cryptoCallbacks from './crypto_callbacks';
Expand Down Expand Up @@ -372,8 +373,10 @@ export class AutoEncrypter {
async encrypt(
ns: string,
cmd: Document,
options: CommandOptions = {}
options: CommandOptions & Abortable = {}
): Promise<Document | Uint8Array> {
options.signal?.throwIfAborted();

if (this._bypassEncryption) {
// If `bypassAutoEncryption` has been specified, don't encrypt
return cmd;
Expand All @@ -398,7 +401,7 @@ export class AutoEncrypter {
socketOptions: autoSelectSocketOptions(this._client.s.options)
});

return deserialize(await stateMachine.execute(this, context, options.timeoutContext), {
return deserialize(await stateMachine.execute(this, context, options), {
promoteValues: false,
promoteLongs: false
});
Expand All @@ -407,7 +410,12 @@ export class AutoEncrypter {
/**
* Decrypt a command response
*/
async decrypt(response: Uint8Array, options: CommandOptions = {}): Promise<Uint8Array> {
async decrypt(
response: Uint8Array,
options: CommandOptions & Abortable = {}
): Promise<Uint8Array> {
options.signal?.throwIfAborted();

const context = this._mongocrypt.makeDecryptionContext(response);

context.id = this._contextCounter++;
Expand All @@ -419,11 +427,7 @@ export class AutoEncrypter {
socketOptions: autoSelectSocketOptions(this._client.s.options)
});

return await stateMachine.execute(
this,
context,
options.timeoutContext?.csotEnabled() ? options.timeoutContext : undefined
);
return await stateMachine.execute(this, context, options);
}

/**
Expand Down
10 changes: 6 additions & 4 deletions src/client-side-encryption/client_encryption.ts
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ export class ClientEncryption {
TimeoutContext.create(resolveTimeoutOptions(this._client, { timeoutMS: this._timeoutMS }));

const dataKey = deserialize(
await stateMachine.execute(this, context, timeoutContext)
await stateMachine.execute(this, context, { timeoutContext })
) as DataKey;

const { db: dbName, collection: collectionName } = MongoDBCollectionNamespace.fromString(
Expand Down Expand Up @@ -293,7 +293,9 @@ export class ClientEncryption {
resolveTimeoutOptions(this._client, { timeoutMS: this._timeoutMS })
);

const { v: dataKeys } = deserialize(await stateMachine.execute(this, context, timeoutContext));
const { v: dataKeys } = deserialize(
await stateMachine.execute(this, context, { timeoutContext })
);
if (dataKeys.length === 0) {
return {};
}
Expand Down Expand Up @@ -696,7 +698,7 @@ export class ClientEncryption {
? TimeoutContext.create(resolveTimeoutOptions(this._client, { timeoutMS: this._timeoutMS }))
: undefined;

const { v } = deserialize(await stateMachine.execute(this, context, timeoutContext));
const { v } = deserialize(await stateMachine.execute(this, context, { timeoutContext }));

return v;
}
Expand Down Expand Up @@ -780,7 +782,7 @@ export class ClientEncryption {
this._timeoutMS != null
? TimeoutContext.create(resolveTimeoutOptions(this._client, { timeoutMS: this._timeoutMS }))
: undefined;
const { v } = deserialize(await stateMachine.execute(this, context, timeoutContext));
const { v } = deserialize(await stateMachine.execute(this, context, { timeoutContext }));
return v;
}
}
Expand Down
116 changes: 80 additions & 36 deletions src/client-side-encryption/state_machine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,15 @@ import { CursorTimeoutContext } from '../cursor/abstract_cursor';
import { getSocks, type SocksLib } from '../deps';
import { MongoOperationTimeoutError } from '../error';
import { type MongoClient, type MongoClientOptions } from '../mongo_client';
import { type Abortable } from '../mongo_types';
import { Timeout, type TimeoutContext, TimeoutError } from '../timeout';
import { BufferPool, MongoDBCollectionNamespace, promiseWithResolvers } from '../utils';
import {
addAbortListener,
BufferPool,
kDispose,
MongoDBCollectionNamespace,
promiseWithResolvers
} from '../utils';
import { autoSelectSocketOptions, type DataKey } from './client_encryption';
import { MongoCryptError } from './errors';
import { type MongocryptdManager } from './mongocryptd_manager';
Expand Down Expand Up @@ -189,7 +196,7 @@ export class StateMachine {
async execute(
executor: StateMachineExecutable,
context: MongoCryptContext,
timeoutContext?: TimeoutContext
options: { timeoutContext?: TimeoutContext } & Abortable
): Promise<Uint8Array> {
const keyVaultNamespace = executor._keyVaultNamespace;
const keyVaultClient = executor._keyVaultClient;
Expand All @@ -199,6 +206,7 @@ export class StateMachine {
let result: Uint8Array | null = null;

while (context.state !== MONGOCRYPT_CTX_DONE && context.state !== MONGOCRYPT_CTX_ERROR) {
options.signal?.throwIfAborted();
debug(`[context#${context.id}] ${stateToString.get(context.state) || context.state}`);

switch (context.state) {
Expand All @@ -214,7 +222,7 @@ export class StateMachine {
metaDataClient,
context.ns,
filter,
timeoutContext
options
);
if (collInfo) {
context.addMongoOperationResponse(collInfo);
Expand All @@ -235,9 +243,9 @@ export class StateMachine {
// When we are using the shared library, we don't have a mongocryptd manager.
const markedCommand: Uint8Array = mongocryptdManager
? await mongocryptdManager.withRespawn(
this.markCommand.bind(this, mongocryptdClient, context.ns, command, timeoutContext)
this.markCommand.bind(this, mongocryptdClient, context.ns, command, options)
)
: await this.markCommand(mongocryptdClient, context.ns, command, timeoutContext);
: await this.markCommand(mongocryptdClient, context.ns, command, options);

context.addMongoOperationResponse(markedCommand);
context.finishMongoOperation();
Expand All @@ -246,12 +254,7 @@ export class StateMachine {

case MONGOCRYPT_CTX_NEED_MONGO_KEYS: {
const filter = context.nextMongoOperation();
const keys = await this.fetchKeys(
keyVaultClient,
keyVaultNamespace,
filter,
timeoutContext
);
const keys = await this.fetchKeys(keyVaultClient, keyVaultNamespace, filter, options);

if (keys.length === 0) {
// See docs on EMPTY_V
Expand All @@ -273,7 +276,7 @@ export class StateMachine {
}

case MONGOCRYPT_CTX_NEED_KMS: {
await Promise.all(this.requests(context, timeoutContext));
await Promise.all(this.requests(context, options));
context.finishKMSRequests();
break;
}
Expand Down Expand Up @@ -315,11 +318,13 @@ export class StateMachine {
* @param kmsContext - A C++ KMS context returned from the bindings
* @returns A promise that resolves when the KMS reply has be fully parsed
*/
async kmsRequest(request: MongoCryptKMSRequest, timeoutContext?: TimeoutContext): Promise<void> {
async kmsRequest(
request: MongoCryptKMSRequest,
options?: { timeoutContext?: TimeoutContext } & Abortable
): Promise<void> {
const parsedUrl = request.endpoint.split(':');
const port = parsedUrl[1] != null ? Number.parseInt(parsedUrl[1], 10) : HTTPS_PORT;
const socketOptions = autoSelectSocketOptions(this.options.socketOptions || {});
const options: tls.ConnectionOptions & {
const socketOptions: tls.ConnectionOptions & {
host: string;
port: number;
autoSelectFamily?: boolean;
Expand All @@ -328,7 +333,7 @@ export class StateMachine {
host: parsedUrl[0],
servername: parsedUrl[0],
port,
...socketOptions
...autoSelectSocketOptions(this.options.socketOptions || {})
};
const message = request.message;
const buffer = new BufferPool();
Expand Down Expand Up @@ -363,7 +368,7 @@ export class StateMachine {
throw error;
}
try {
await this.setTlsOptions(providerTlsOptions, options);
await this.setTlsOptions(providerTlsOptions, socketOptions);
} catch (err) {
throw onerror(err);
}
Expand All @@ -380,23 +385,25 @@ export class StateMachine {
.once('close', () => rejectOnNetSocketError(onclose()))
.once('connect', () => resolveOnNetSocketConnect());

let abortListener;

try {
if (this.options.proxyOptions && this.options.proxyOptions.proxyHost) {
const netSocketOptions = {
...socketOptions,
host: this.options.proxyOptions.proxyHost,
port: this.options.proxyOptions.proxyPort || 1080,
...socketOptions
port: this.options.proxyOptions.proxyPort || 1080
};
netSocket.connect(netSocketOptions);
await willConnect;

try {
socks ??= loadSocks();
options.socket = (
socketOptions.socket = (
await socks.SocksClient.createConnection({
existing_socket: netSocket,
command: 'connect',
destination: { host: options.host, port: options.port },
destination: { host: socketOptions.host, port: socketOptions.port },
proxy: {
// host and port are ignored because we pass existing_socket
host: 'iLoveJavaScript',
Expand All @@ -412,7 +419,7 @@ export class StateMachine {
}
}

socket = tls.connect(options, () => {
socket = tls.connect(socketOptions, () => {
socket.write(message);
});

Expand All @@ -422,6 +429,11 @@ export class StateMachine {
resolve
} = promiseWithResolvers<void>();

abortListener = addAbortListener(options?.signal, function () {
destroySockets();
rejectOnTlsSocketError(this.reason);
});

socket
.once('error', err => rejectOnTlsSocketError(onerror(err)))
.once('close', () => rejectOnTlsSocketError(onclose()))
Expand All @@ -436,8 +448,11 @@ export class StateMachine {
resolve();
}
});
await (timeoutContext?.csotEnabled()
? Promise.all([willResolveKmsRequest, Timeout.expires(timeoutContext?.remainingTimeMS)])
await (options?.timeoutContext?.csotEnabled()
? Promise.all([
willResolveKmsRequest,
Timeout.expires(options.timeoutContext?.remainingTimeMS)
])
: willResolveKmsRequest);
} catch (error) {
if (error instanceof TimeoutError)
Expand All @@ -446,16 +461,17 @@ export class StateMachine {
} finally {
// There's no need for any more activity on this socket at this point.
destroySockets();
abortListener?.[kDispose]();
}
}

*requests(context: MongoCryptContext, timeoutContext?: TimeoutContext) {
*requests(context: MongoCryptContext, options?: { timeoutContext?: TimeoutContext } & Abortable) {
for (
let request = context.nextKMSRequest();
request != null;
request = context.nextKMSRequest()
) {
yield this.kmsRequest(request, timeoutContext);
yield this.kmsRequest(request, options);
}
}

Expand Down Expand Up @@ -516,14 +532,16 @@ export class StateMachine {
client: MongoClient,
ns: string,
filter: Document,
timeoutContext?: TimeoutContext
options?: { timeoutContext?: TimeoutContext } & Abortable
): Promise<Uint8Array | null> {
const { db } = MongoDBCollectionNamespace.fromString(ns);

const cursor = client.db(db).listCollections(filter, {
promoteLongs: false,
promoteValues: false,
timeoutContext: timeoutContext && new CursorTimeoutContext(timeoutContext, Symbol())
timeoutContext:
options?.timeoutContext && new CursorTimeoutContext(options?.timeoutContext, Symbol()),
signal: options?.signal
});

// There is always exactly zero or one matching documents, so this should always exhaust the cursor
Expand All @@ -547,17 +565,30 @@ export class StateMachine {
client: MongoClient,
ns: string,
command: Uint8Array,
timeoutContext?: TimeoutContext
options?: { timeoutContext?: TimeoutContext } & Abortable
): Promise<Uint8Array> {
const { db } = MongoDBCollectionNamespace.fromString(ns);
const bsonOptions = { promoteLongs: false, promoteValues: false };
const rawCommand = deserialize(command, bsonOptions);

const commandOptions: {
timeoutMS?: number;
signal?: AbortSignal;
} = {
timeoutMS: undefined,
signal: undefined
};

if (options?.timeoutContext?.csotEnabled()) {
commandOptions.timeoutMS = options.timeoutContext.remainingTimeMS;
}
if (options?.signal) {
commandOptions.signal = options.signal;
}

const response = await client.db(db).command(rawCommand, {
...bsonOptions,
...(timeoutContext?.csotEnabled()
? { timeoutMS: timeoutContext?.remainingTimeMS }
: undefined)
...commandOptions
});

return serialize(response, this.bsonOptions);
Expand All @@ -575,17 +606,30 @@ export class StateMachine {
client: MongoClient,
keyVaultNamespace: string,
filter: Uint8Array,
timeoutContext?: TimeoutContext
options?: { timeoutContext?: TimeoutContext } & Abortable
): Promise<Array<DataKey>> {
const { db: dbName, collection: collectionName } =
MongoDBCollectionNamespace.fromString(keyVaultNamespace);

const commandOptions: {
timeoutContext?: CursorTimeoutContext;
signal?: AbortSignal;
} = {
timeoutContext: undefined,
signal: undefined
};

if (options?.timeoutContext != null) {
commandOptions.timeoutContext = new CursorTimeoutContext(options.timeoutContext, Symbol());
}
if (options?.signal != null) {
commandOptions.signal = options.signal;
}

return client
.db(dbName)
.collection<DataKey>(collectionName, { readConcern: { level: 'majority' } })
.find(deserialize(filter), {
timeoutContext: timeoutContext && new CursorTimeoutContext(timeoutContext, Symbol())
})
.find(deserialize(filter), commandOptions)
.toArray();
}
}
Loading
Loading