Skip to content

Commit 01cbfe7

Browse files
fix: add guards against possible memory overflow in find and aggregate tools MCP-21 (#536)
1 parent c10955a commit 01cbfe7

17 files changed

+1116
-78
lines changed

src/common/config.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import levenshtein from "ts-levenshtein";
99

1010
// From: https://github.com/mongodb-js/mongosh/blob/main/packages/cli-repl/src/arg-parser.ts
1111
const OPTIONS = {
12+
number: ["maxDocumentsPerQuery", "maxBytesPerQuery"],
1213
string: [
1314
"apiBaseUrl",
1415
"apiClientId",
@@ -98,6 +99,7 @@ const OPTIONS = {
9899

99100
interface Options {
100101
string: string[];
102+
number: string[];
101103
boolean: string[];
102104
array: string[];
103105
alias: Record<string, string>;
@@ -106,6 +108,7 @@ interface Options {
106108

107109
export const ALL_CONFIG_KEYS = new Set(
108110
(OPTIONS.string as readonly string[])
111+
.concat(OPTIONS.number)
109112
.concat(OPTIONS.array)
110113
.concat(OPTIONS.boolean)
111114
.concat(Object.keys(OPTIONS.alias))
@@ -175,6 +178,8 @@ export interface UserConfig extends CliOptions {
175178
loggers: Array<"stderr" | "disk" | "mcp">;
176179
idleTimeoutMs: number;
177180
notificationTimeoutMs: number;
181+
maxDocumentsPerQuery: number;
182+
maxBytesPerQuery: number;
178183
atlasTemporaryDatabaseUserLifetimeMs: number;
179184
}
180185

@@ -202,6 +207,8 @@ export const defaultUserConfig: UserConfig = {
202207
idleTimeoutMs: 10 * 60 * 1000, // 10 minutes
203208
notificationTimeoutMs: 9 * 60 * 1000, // 9 minutes
204209
httpHeaders: {},
210+
maxDocumentsPerQuery: 100, // By default, we only fetch a maximum 100 documents per query / aggregation
211+
maxBytesPerQuery: 16 * 1024 * 1024, // By default, we only return ~16 mb of data per query / aggregation
205212
atlasTemporaryDatabaseUserLifetimeMs: 4 * 60 * 60 * 1000, // 4 hours
206213
};
207214

src/common/logger.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ export const LogId = {
4444
mongodbConnectFailure: mongoLogId(1_004_001),
4545
mongodbDisconnectFailure: mongoLogId(1_004_002),
4646
mongodbConnectTry: mongoLogId(1_004_003),
47+
mongodbCursorCloseError: mongoLogId(1_004_004),
4748

4849
toolUpdateFailure: mongoLogId(1_005_001),
4950
resourceUpdateFailure: mongoLogId(1_005_002),
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import { calculateObjectSize } from "bson";
2+
import type { AggregationCursor, FindCursor } from "mongodb";
3+
4+
/**
 * Derives the effective byte limit for a tool response from the tool call's
 * `responseBytesLimit` argument and the server's configured
 * `maxBytesPerQuery`.
 *
 * A limit is treated as "not applicable" (disabled) when it is nullish, not a
 * number, `NaN`, zero or negative. When both limits are applicable the smaller
 * one wins; on a tie the tool limit is reported as the capping source.
 *
 * @param toolResponseBytesLimit - Limit requested through the tool call.
 * @param configuredMaxBytesPerQuery - Server-configured limit. It is coerced
 * with `parseInt(String(value), 10)`, so numeric strings are accepted.
 * @returns `limit` is the effective byte limit, where `0` means that no limit
 * applies. `cappedBy` names the source of that limit and is `undefined` when no
 * limit applies.
 */
export function getResponseBytesLimit(
    toolResponseBytesLimit: number | undefined | null,
    configuredMaxBytesPerQuery: unknown
): {
    cappedBy: "config.maxBytesPerQuery" | "tool.responseBytesLimit" | undefined;
    limit: number;
} {
    const configuredLimit = parseInt(String(configuredMaxBytesPerQuery), 10);

    // Setting configured maxBytesPerQuery to negative, zero or nullish is
    // equivalent to disabling the max limit applied on documents.
    const configuredLimitIsNotApplicable = Number.isNaN(configuredLimit) || configuredLimit <= 0;

    // The tool limit only counts when it is a positive number. The `> 0` check
    // (rather than a negated `<= 0`) also rejects NaN: NaN has typeof "number"
    // and fails every comparison, so it would otherwise be returned as the
    // limit and silently bypass the configured maxBytesPerQuery.
    const toolLimit =
        typeof toolResponseBytesLimit === "number" && toolResponseBytesLimit > 0
            ? toolResponseBytesLimit
            : undefined;

    if (configuredLimitIsNotApplicable) {
        return toolLimit === undefined
            ? { cappedBy: undefined, limit: 0 }
            : { cappedBy: "tool.responseBytesLimit", limit: toolLimit };
    }

    if (toolLimit === undefined) {
        return { cappedBy: "config.maxBytesPerQuery", limit: configuredLimit };
    }

    return {
        cappedBy: configuredLimit < toolLimit ? "config.maxBytesPerQuery" : "tool.responseBytesLimit",
        limit: Math.min(toolLimit, configuredLimit),
    };
}
39+
40+
/**
 * Guard rail against accidental memory overflow on the MCP server.
 *
 * The effective byte limit is derived from the tool's `responseBytesLimit` and
 * the server's configured `maxBytesPerQuery`. When that limit is zero or
 * negative the whole cursor is drained with `toArray()`. Otherwise documents
 * are pulled one at a time and buffered for as long as the running size (as
 * reported by `calculateObjectSize`) stays below the limit; the document that
 * would reach the limit is not included in the result.
 *
 * Iteration also stops, without marking the result as capped, when the cursor
 * yields no further document or when `abortSignal` is aborted.
 *
 * @returns The buffered documents, plus the name of the limit that cut the
 * result short (`undefined` when the result was not cut short by a limit).
 */
export async function collectCursorUntilMaxBytesLimit<T = unknown>({
    cursor,
    toolResponseBytesLimit,
    configuredMaxBytesPerQuery,
    abortSignal,
}: {
    cursor: FindCursor<T> | AggregationCursor<T>;
    toolResponseBytesLimit: number | undefined | null;
    configuredMaxBytesPerQuery: unknown;
    abortSignal?: AbortSignal;
}): Promise<{ cappedBy: "config.maxBytesPerQuery" | "tool.responseBytesLimit" | undefined; documents: T[] }> {
    const bytesLimit = getResponseBytesLimit(toolResponseBytesLimit, configuredMaxBytesPerQuery);

    // Both config.maxBytesPerQuery and tool.responseBytesLimit may be disabled
    // (nullish or non-positive), in which case nothing restricts the response.
    if (bytesLimit.limit <= 0) {
        const everything = await cursor.toArray();
        return { cappedBy: bytesLimit.cappedBy, documents: everything };
    }

    const collected: T[] = [];
    let collectedBytes = 0;
    let limitReached = false;

    while (!abortSignal?.aborted) {
        // Nothing more to read once the cursor stops yielding documents.
        const candidate = await cursor.tryNext();
        if (!candidate) {
            break;
        }

        const candidateBytes = calculateObjectSize(candidate);
        if (collectedBytes + candidateBytes >= bytesLimit.limit) {
            limitReached = true;
            break;
        }

        collected.push(candidate);
        collectedBytes += candidateBytes;
    }

    return {
        cappedBy: limitReached ? bytesLimit.cappedBy : undefined,
        documents: collected,
    };
}

src/helpers/constants.ts

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/**
 * Upper bound for the maxTimeMS used for FindCursor.countDocuments.
 *
 * The value is kept comparatively small: the count query is expected to finish
 * no later than the batch of documents is retrieved, so that it doesn't hold
 * the final response back.
 */
export const QUERY_COUNT_MAX_TIME_MS_CAP: number = 10 * 1000;

/**
 * Upper bound for the maxTimeMS used when counting the documents an
 * aggregation results in.
 */
export const AGG_COUNT_MAX_TIME_MS_CAP: number = 60 * 1000;

/** Number of bytes in one megabyte, counted as 1024 * 1024. */
export const ONE_MB: number = 1024 * 1024;

/**
 * Maps each limit that can be applied to a cursor to the text that is sent to
 * the LLM to describe that limit.
 */
export const CURSOR_LIMITS_TO_LLM_TEXT = {
    "config.maxDocumentsPerQuery": "server's configured - maxDocumentsPerQuery",
    "config.maxBytesPerQuery": "server's configured - maxBytesPerQuery",
    "tool.responseBytesLimit": "tool's parameter - responseBytesLimit",
} as const;
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
type OperationCallback<OperationResult> = () => Promise<OperationResult>;

/**
 * Runs an asynchronous operation and resolves with `fallback` instead of
 * rejecting when the operation fails.
 *
 * @param performOperation - The operation to attempt.
 * @param fallback - Value to resolve with when the operation throws or rejects.
 * @param onError - Optional hook that receives the failure, e.g. for logging.
 * Without it the failure is discarded, which is the original behaviour.
 * @returns The operation's result, or `fallback` when the operation failed.
 * The returned promise never rejects.
 */
export async function operationWithFallback<OperationResult, FallbackValue>(
    performOperation: OperationCallback<OperationResult>,
    fallback: FallbackValue,
    onError?: (error: unknown) => void
): Promise<OperationResult | FallbackValue> {
    try {
        return await performOperation();
    } catch (error: unknown) {
        try {
            onError?.(error);
        } catch {
            // The hook is best-effort: a failing hook must not break the
            // guarantee that the fallback value is returned.
        }
        return fallback;
    }
}
Lines changed: 135 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,25 @@
11
import { z } from "zod";
2+
import type { AggregationCursor } from "mongodb";
23
import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
4+
import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver";
35
import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js";
4-
import type { ToolArgs, OperationType } from "../../tool.js";
6+
import type { ToolArgs, OperationType, ToolExecutionContext } from "../../tool.js";
57
import { formatUntrustedData } from "../../tool.js";
68
import { checkIndexUsage } from "../../../helpers/indexCheck.js";
7-
import { EJSON } from "bson";
9+
import { type Document, EJSON } from "bson";
810
import { ErrorCodes, MongoDBError } from "../../../common/errors.js";
11+
import { collectCursorUntilMaxBytesLimit } from "../../../helpers/collectCursorUntilMaxBytes.js";
12+
import { operationWithFallback } from "../../../helpers/operationWithFallback.js";
13+
import { AGG_COUNT_MAX_TIME_MS_CAP, ONE_MB, CURSOR_LIMITS_TO_LLM_TEXT } from "../../../helpers/constants.js";
914
import { zEJSON } from "../../args.js";
15+
import { LogId } from "../../../common/logger.js";
1016

1117
/**
 * Argument shape of the aggregate tool: the pipeline to run and an optional
 * cap, defaulting to ONE_MB, on the number of bytes returned in the response.
 */
export const AggregateArgs = {
    pipeline: z.array(zEJSON()).describe("An array of aggregation stages to execute"),
    responseBytesLimit: z
        .number()
        .optional()
        .default(ONE_MB)
        .describe(
            [
                "The maximum number of bytes to return in the response. This value is capped by the server’s configured maxBytesPerQuery and cannot be exceeded.",
                'Note to LLM: If the entire aggregation result is required, use the "export" tool instead of increasing this limit.',
            ].join(" ")
        ),
};
1424

1525
export class AggregateTool extends MongoDBToolBase {
@@ -21,32 +31,80 @@ export class AggregateTool extends MongoDBToolBase {
2131
};
2232
public operationType: OperationType = "read";
2333

24-
protected async execute({
25-
database,
26-
collection,
27-
pipeline,
28-
}: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
29-
const provider = await this.ensureConnected();
34+
protected async execute(
35+
{ database, collection, pipeline, responseBytesLimit }: ToolArgs<typeof this.argsShape>,
36+
{ signal }: ToolExecutionContext
37+
): Promise<CallToolResult> {
38+
let aggregationCursor: AggregationCursor | undefined = undefined;
39+
try {
40+
const provider = await this.ensureConnected();
3041

31-
this.assertOnlyUsesPermittedStages(pipeline);
42+
this.assertOnlyUsesPermittedStages(pipeline);
3243

33-
// Check if aggregate operation uses an index if enabled
34-
if (this.config.indexCheck) {
35-
await checkIndexUsage(provider, database, collection, "aggregate", async () => {
36-
return provider
37-
.aggregate(database, collection, pipeline, {}, { writeConcern: undefined })
38-
.explain("queryPlanner");
39-
});
40-
}
44+
// Check if aggregate operation uses an index if enabled
45+
if (this.config.indexCheck) {
46+
await checkIndexUsage(provider, database, collection, "aggregate", async () => {
47+
return provider
48+
.aggregate(database, collection, pipeline, {}, { writeConcern: undefined })
49+
.explain("queryPlanner");
50+
});
51+
}
4152

42-
const documents = await provider.aggregate(database, collection, pipeline).toArray();
53+
const cappedResultsPipeline = [...pipeline];
54+
if (this.config.maxDocumentsPerQuery > 0) {
55+
cappedResultsPipeline.push({ $limit: this.config.maxDocumentsPerQuery });
56+
}
57+
aggregationCursor = provider.aggregate(database, collection, cappedResultsPipeline);
4358

44-
return {
45-
content: formatUntrustedData(
46-
`The aggregation resulted in ${documents.length} documents.`,
47-
documents.length > 0 ? EJSON.stringify(documents) : undefined
48-
),
49-
};
59+
const [totalDocuments, cursorResults] = await Promise.all([
60+
this.countAggregationResultDocuments({ provider, database, collection, pipeline }),
61+
collectCursorUntilMaxBytesLimit({
62+
cursor: aggregationCursor,
63+
configuredMaxBytesPerQuery: this.config.maxBytesPerQuery,
64+
toolResponseBytesLimit: responseBytesLimit,
65+
abortSignal: signal,
66+
}),
67+
]);
68+
69+
// If the total number of documents that the aggregation would've
70+
// resulted in would be greater than the configured
71+
// maxDocumentsPerQuery then we know for sure that the results were
72+
// capped.
73+
const aggregationResultsCappedByMaxDocumentsLimit =
74+
this.config.maxDocumentsPerQuery > 0 &&
75+
!!totalDocuments &&
76+
totalDocuments > this.config.maxDocumentsPerQuery;
77+
78+
return {
79+
content: formatUntrustedData(
80+
this.generateMessage({
81+
aggResultsCount: totalDocuments,
82+
documents: cursorResults.documents,
83+
appliedLimits: [
84+
aggregationResultsCappedByMaxDocumentsLimit ? "config.maxDocumentsPerQuery" : undefined,
85+
cursorResults.cappedBy,
86+
].filter((limit): limit is keyof typeof CURSOR_LIMITS_TO_LLM_TEXT => !!limit),
87+
}),
88+
cursorResults.documents.length > 0 ? EJSON.stringify(cursorResults.documents) : undefined
89+
),
90+
};
91+
} finally {
92+
if (aggregationCursor) {
93+
void this.safeCloseCursor(aggregationCursor);
94+
}
95+
}
96+
}
97+
98+
/**
 * Closes the given cursor without ever rejecting: a failure to close is
 * reported as a warning through the session logger instead of being thrown.
 */
private async safeCloseCursor(cursor: AggregationCursor<unknown>): Promise<void> {
    try {
        await cursor.close();
    } catch (closeError) {
        const reason = closeError instanceof Error ? closeError.message : String(closeError);
        this.session.logger.warning({
            id: LogId.mongodbCursorCloseError,
            context: "aggregate tool",
            message: `Error when closing the cursor - ${reason}`,
        });
    }
}
51109

52110
private assertOnlyUsesPermittedStages(pipeline: Record<string, unknown>[]): void {
@@ -70,4 +128,57 @@ export class AggregateTool extends MongoDBToolBase {
70128
}
71129
}
72130
}
131+
132+
/**
 * Counts the documents the given pipeline results in by running it with a
 * trailing `$count` stage, bounded by AGG_COUNT_MAX_TIME_MS_CAP.
 *
 * @returns The number of resulting documents; `0` when the count aggregation
 * yields no document; `undefined` when the count could not be determined,
 * either because the aggregation failed or because its output did not have
 * the expected `{ totalDocuments: number }` shape.
 */
private async countAggregationResultDocuments({
    provider,
    database,
    collection,
    pipeline,
}: {
    provider: NodeDriverServiceProvider;
    database: string;
    collection: string;
    pipeline: Document[];
}): Promise<number | undefined> {
    const resultsCountAggregation = [...pipeline, { $count: "totalDocuments" }];
    return await operationWithFallback(async (): Promise<number | undefined> => {
        const aggregationResults: unknown[] = await provider
            .aggregate(database, collection, resultsCountAggregation)
            .maxTimeMS(AGG_COUNT_MAX_TIME_MS_CAP)
            .toArray();

        // $count produces no output document when no documents reach the
        // stage, so an empty result is a genuine count of zero.
        if (aggregationResults.length === 0) {
            return 0;
        }

        const [documentWithCount] = aggregationResults;
        if (
            aggregationResults.length === 1 &&
            typeof documentWithCount === "object" &&
            documentWithCount !== null &&
            "totalDocuments" in documentWithCount &&
            typeof documentWithCount.totalDocuments === "number"
        ) {
            return documentWithCount.totalDocuments;
        }

        // Anything else (e.g. a count that is not a plain number) is an
        // undeterminable count rather than "zero documents".
        return undefined;
    }, undefined);
}
162+
163+
/**
 * Builds the summary text that accompanies the aggregation results.
 *
 * The text states how many documents the aggregation resulted in (or that the
 * number is indeterminable when `aggResultsCount` is `undefined`) and how many
 * are being returned. When limits were applied it also names them and points
 * the LLM to the "export" tool.
 */
private generateMessage({
    aggResultsCount,
    documents,
    appliedLimits,
}: {
    aggResultsCount: number | undefined;
    documents: unknown[];
    appliedLimits: (keyof typeof CURSOR_LIMITS_TO_LLM_TEXT)[];
}): string {
    const resultCountText = aggResultsCount === undefined ? "indeterminable number of" : aggResultsCount;
    const summary = `The aggregation resulted in ${resultCountText} documents. Returning ${documents.length} documents`;

    // Without applied limits the summary simply ends the sentence.
    if (appliedLimits.length === 0) {
        return `${summary}.`;
    }

    const limitNames = appliedLimits.map((limit) => CURSOR_LIMITS_TO_LLM_TEXT[limit]).join(", ");
    return (
        `${summary} while respecting the applied limits of ${limitNames}. ` +
        'Note to LLM: If the entire query result is required then use "export" tool to export the query results.'
    );
}
73184
}

0 commit comments

Comments
 (0)