Skip to content

Commit 464c37a

Browse files
himanshusinghsCopilotgagik
authored
chore: adds validation for vector search stage's pre-filter expression MCP-242 (#696)
Co-authored-by: Copilot <[email protected]> Co-authored-by: gagik <[email protected]>
1 parent f56f772 commit 464c37a

File tree

10 files changed

+1219
-338
lines changed

10 files changed

+1219
-338
lines changed

src/common/logger.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ export const LogId = {
4949
toolUpdateFailure: mongoLogId(1_005_001),
5050
resourceUpdateFailure: mongoLogId(1_005_002),
5151
updateToolMetadata: mongoLogId(1_005_003),
52+
toolValidationError: mongoLogId(1_005_004),
5253

5354
streamableHttpTransportStarted: mongoLogId(1_006_001),
5455
streamableHttpTransportSessionCloseFailure: mongoLogId(1_006_002),

src/common/search/embeddingsProvider.ts

Lines changed: 6 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@ import { embedMany } from "ai";
44
import type { UserConfig } from "../config.js";
55
import assert from "assert";
66
import { createFetch } from "@mongodb-js/devtools-proxy-support";
7-
import { z } from "zod";
7+
import {
8+
type EmbeddingParameters,
9+
type VoyageEmbeddingParameters,
10+
type VoyageModels,
11+
zVoyageAPIParameters,
12+
} from "../../tools/mongodb/mongodbSchemas.js";
813

914
type EmbeddingsInput = string;
1015
type Embeddings = number[] | unknown[];
11-
export type EmbeddingParameters = {
12-
inputType: "query" | "document";
13-
};
1416

1517
export interface EmbeddingsProvider<
1618
SupportedModels extends string,
@@ -23,40 +25,6 @@ export interface EmbeddingsProvider<
2325
): Promise<Embeddings[]>;
2426
}
2527

26-
export const zVoyageModels = z
27-
.enum(["voyage-3-large", "voyage-3.5", "voyage-3.5-lite", "voyage-code-3"])
28-
.default("voyage-3-large");
29-
30-
// Zod does not undestand JS boxed numbers (like Int32) as integer literals,
31-
// so we preprocess them to unwrap them so Zod understands them.
32-
function unboxNumber(v: unknown): number {
33-
if (v && typeof v === "object" && typeof v.valueOf === "function") {
34-
const n = Number(v.valueOf());
35-
if (!Number.isNaN(n)) return n;
36-
}
37-
return v as number;
38-
}
39-
40-
export const zVoyageEmbeddingParameters = z.object({
41-
outputDimension: z
42-
.preprocess(
43-
unboxNumber,
44-
z.union([z.literal(256), z.literal(512), z.literal(1024), z.literal(2048), z.literal(4096)])
45-
)
46-
.optional()
47-
.default(1024),
48-
outputDtype: z.enum(["float", "int8", "uint8", "binary", "ubinary"]).optional().default("float"),
49-
});
50-
51-
const zVoyageAPIParameters = zVoyageEmbeddingParameters
52-
.extend({
53-
inputType: z.enum(["query", "document"]),
54-
})
55-
.strip();
56-
57-
type VoyageModels = z.infer<typeof zVoyageModels>;
58-
type VoyageEmbeddingParameters = z.infer<typeof zVoyageEmbeddingParameters> & EmbeddingParameters;
59-
6028
class VoyageEmbeddingsProvider implements EmbeddingsProvider<VoyageModels, VoyageEmbeddingParameters> {
6129
private readonly voyage: VoyageProvider;
6230

@@ -105,6 +73,3 @@ export function getEmbeddingsProvider(
10573

10674
return undefined;
10775
}
108-
109-
export const zSupportedEmbeddingParameters = zVoyageEmbeddingParameters.extend({ model: zVoyageModels });
110-
export type SupportedEmbeddingParameters = z.infer<typeof zSupportedEmbeddingParameters>;

src/common/search/vectorSearchEmbeddingsManager.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@ import type { ConnectionManager } from "../connectionManager.js";
55
import z from "zod";
66
import { ErrorCodes, MongoDBError } from "../errors.js";
77
import { getEmbeddingsProvider } from "./embeddingsProvider.js";
8-
import type { EmbeddingParameters, SupportedEmbeddingParameters } from "./embeddingsProvider.js";
8+
import type { EmbeddingParameters } from "../../tools/mongodb/mongodbSchemas.js";
99
import { formatUntrustedData } from "../../tools/tool.js";
1010
import type { Similarity } from "../schemas.js";
11+
import type { SupportedEmbeddingParameters } from "../../tools/mongodb/mongodbSchemas.js";
1112

1213
export const quantizationEnum = z.enum(["none", "scalar", "binary"]);
1314
export type Quantization = z.infer<typeof quantizationEnum>;
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
// Based on -
2+
3+
import type z from "zod";
4+
import { ErrorCodes, MongoDBError } from "../common/errors.js";
5+
import type { VectorSearchStage } from "../tools/mongodb/mongodbSchemas.js";
6+
import { type CompositeLogger, LogId } from "../common/logger.js";
7+
8+
// https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#mongodb-vector-search-pre-filter
9+
const ALLOWED_LOGICAL_OPERATORS = ["$not", "$nor", "$and", "$or"];
10+
11+
export type VectorSearchIndex = {
12+
name: string;
13+
latestDefinition: {
14+
fields: Array<
15+
| {
16+
type: "vector";
17+
}
18+
| {
19+
type: "filter";
20+
path: string;
21+
}
22+
>;
23+
};
24+
};
25+
26+
export function assertVectorSearchFilterFieldsAreIndexed({
27+
searchIndexes,
28+
pipeline,
29+
logger,
30+
}: {
31+
searchIndexes: VectorSearchIndex[];
32+
pipeline: Record<string, unknown>[];
33+
logger: CompositeLogger;
34+
}): void {
35+
const searchIndexesWithFilterFields = searchIndexes.reduce<Record<string, string[]>>(
36+
(indexFieldMap, searchIndex) => {
37+
const filterFields = searchIndex.latestDefinition.fields
38+
.map<string | undefined>((field) => {
39+
return field.type === "filter" ? field.path : undefined;
40+
})
41+
.filter((filterField) => filterField !== undefined);
42+
43+
indexFieldMap[searchIndex.name] = filterFields;
44+
return indexFieldMap;
45+
},
46+
{}
47+
);
48+
for (const stage of pipeline) {
49+
if ("$vectorSearch" in stage) {
50+
const { $vectorSearch: vectorSearchStage } = stage as z.infer<typeof VectorSearchStage>;
51+
const allowedFilterFields = searchIndexesWithFilterFields[vectorSearchStage.index];
52+
if (!allowedFilterFields) {
53+
logger.warning({
54+
id: LogId.toolValidationError,
55+
context: "aggregate tool",
56+
message: `Could not assert if filter fields are indexed - No filter fields found for index ${vectorSearchStage.index}`,
57+
});
58+
return;
59+
}
60+
61+
const filterFieldsInStage = collectFieldsFromVectorSearchFilter(vectorSearchStage.filter);
62+
const filterFieldsNotIndexed = filterFieldsInStage.filter((field) => !allowedFilterFields.includes(field));
63+
if (filterFieldsNotIndexed.length) {
64+
throw new MongoDBError(
65+
ErrorCodes.AtlasVectorSearchInvalidQuery,
66+
`Vector search stage contains filter on fields that are not indexed by index ${vectorSearchStage.index} - ${filterFieldsNotIndexed.join(", ")}`
67+
);
68+
}
69+
}
70+
}
71+
}
72+
73+
export function collectFieldsFromVectorSearchFilter(filter: unknown): string[] {
74+
if (!filter || typeof filter !== "object" || !Object.keys(filter).length) {
75+
return [];
76+
}
77+
78+
const collectedFields = Object.entries(filter).reduce<string[]>((collectedFields, [maybeField, fieldMQL]) => {
79+
if (ALLOWED_LOGICAL_OPERATORS.includes(maybeField) && Array.isArray(fieldMQL)) {
80+
return fieldMQL.flatMap((mql) => collectFieldsFromVectorSearchFilter(mql));
81+
}
82+
83+
if (!ALLOWED_LOGICAL_OPERATORS.includes(maybeField)) {
84+
collectedFields.push(maybeField);
85+
}
86+
return collectedFields;
87+
}, []);
88+
89+
return Array.from(new Set(collectedFields));
90+
}

src/tools/mongodb/create/insertMany.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js";
44
import { type ToolArgs, type OperationType, formatUntrustedData } from "../../tool.js";
55
import { zEJSON } from "../../args.js";
66
import { type Document } from "bson";
7-
import { zSupportedEmbeddingParameters } from "../../../common/search/embeddingsProvider.js";
7+
import { zSupportedEmbeddingParameters } from "../mongodbSchemas.js";
88
import { ErrorCodes, MongoDBError } from "../../../common/errors.js";
99

1010
const zSupportedEmbeddingParametersWithInput = zSupportedEmbeddingParameters.extend({
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import z from "zod";
2+
import { zEJSON } from "../args.js";
3+
4+
export const zVoyageModels = z
5+
.enum(["voyage-3-large", "voyage-3.5", "voyage-3.5-lite", "voyage-code-3"])
6+
.default("voyage-3-large");
7+
8+
// Zod does not undestand JS boxed numbers (like Int32) as integer literals,
9+
// so we preprocess them to unwrap them so Zod understands them.
10+
function unboxNumber(v: unknown): number {
11+
if (v && typeof v === "object" && typeof v.valueOf === "function") {
12+
const n = Number(v.valueOf());
13+
if (!Number.isNaN(n)) return n;
14+
}
15+
return v as number;
16+
}
17+
18+
export const zVoyageEmbeddingParameters = z.object({
19+
outputDimension: z
20+
.preprocess(
21+
unboxNumber,
22+
z.union([z.literal(256), z.literal(512), z.literal(1024), z.literal(2048), z.literal(4096)])
23+
)
24+
.optional()
25+
.default(1024),
26+
outputDtype: z.enum(["float", "int8", "uint8", "binary", "ubinary"]).optional().default("float"),
27+
});
28+
29+
export const zVoyageAPIParameters = zVoyageEmbeddingParameters
30+
.extend({
31+
inputType: z.enum(["query", "document"]),
32+
})
33+
.strip();
34+
35+
export type VoyageModels = z.infer<typeof zVoyageModels>;
36+
export type VoyageEmbeddingParameters = z.infer<typeof zVoyageEmbeddingParameters> & EmbeddingParameters;
37+
38+
export type EmbeddingParameters = {
39+
inputType: "query" | "document";
40+
};
41+
42+
export const zSupportedEmbeddingParameters = zVoyageEmbeddingParameters.extend({ model: zVoyageModels });
43+
export type SupportedEmbeddingParameters = z.infer<typeof zSupportedEmbeddingParameters>;
44+
45+
export const AnyVectorSearchStage = zEJSON();
46+
export const VectorSearchStage = z.object({
47+
$vectorSearch: z
48+
.object({
49+
exact: z
50+
.boolean()
51+
.optional()
52+
.default(false)
53+
.describe(
54+
"When true, uses an ENN algorithm, otherwise uses ANN. Using ENN is not compatible with numCandidates, in that case, numCandidates must be left empty."
55+
),
56+
index: z.string().describe("Name of the index, as retrieved from the `collection-indexes` tool."),
57+
path: z
58+
.string()
59+
.describe(
60+
"Field, in dot notation, where to search. There must be a vector search index for that field. Note to LLM: When unsure, use the 'collection-indexes' tool to validate that the field is indexed with a vector search index."
61+
),
62+
queryVector: z
63+
.union([z.string(), z.array(z.number())])
64+
.describe(
65+
"The content to search for. The embeddingParameters field is mandatory if the queryVector is a string, in that case, the tool generates the embedding automatically using the provided configuration."
66+
),
67+
numCandidates: z
68+
.number()
69+
.int()
70+
.positive()
71+
.optional()
72+
.describe("Number of candidates for the ANN algorithm. Mandatory when exact is false."),
73+
limit: z.number().int().positive().optional().default(10),
74+
filter: zEJSON()
75+
.optional()
76+
.describe(
77+
"MQL filter that can only use filter fields from the index definition. Note to LLM: If unsure, use the `collection-indexes` tool to learn which fields can be used for filtering."
78+
),
79+
embeddingParameters: zSupportedEmbeddingParameters
80+
.optional()
81+
.describe(
82+
"The embedding model and its parameters to use to generate embeddings before searching. It is mandatory if queryVector is a string value. Note to LLM: If unsure, ask the user before providing one."
83+
),
84+
})
85+
.passthrough(),
86+
});

src/tools/mongodb/read/aggregate.ts

Lines changed: 13 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -11,55 +11,15 @@ import { ErrorCodes, MongoDBError } from "../../../common/errors.js";
1111
import { collectCursorUntilMaxBytesLimit } from "../../../helpers/collectCursorUntilMaxBytes.js";
1212
import { operationWithFallback } from "../../../helpers/operationWithFallback.js";
1313
import { AGG_COUNT_MAX_TIME_MS_CAP, ONE_MB, CURSOR_LIMITS_TO_LLM_TEXT } from "../../../helpers/constants.js";
14-
import { zEJSON } from "../../args.js";
1514
import { LogId } from "../../../common/logger.js";
16-
import { zSupportedEmbeddingParameters } from "../../../common/search/embeddingsProvider.js";
17-
18-
const AnyStage = zEJSON();
19-
const VectorSearchStage = z.object({
20-
$vectorSearch: z
21-
.object({
22-
exact: z
23-
.boolean()
24-
.optional()
25-
.default(false)
26-
.describe(
27-
"When true, uses an ENN algorithm, otherwise uses ANN. Using ENN is not compatible with numCandidates, in that case, numCandidates must be left empty."
28-
),
29-
index: z.string().describe("Name of the index, as retrieved from the `collection-indexes` tool."),
30-
path: z
31-
.string()
32-
.describe(
33-
"Field, in dot notation, where to search. There must be a vector search index for that field. Note to LLM: When unsure, use the 'collection-indexes' tool to validate that the field is indexed with a vector search index."
34-
),
35-
queryVector: z
36-
.union([z.string(), z.array(z.number())])
37-
.describe(
38-
"The content to search for. The embeddingParameters field is mandatory if the queryVector is a string, in that case, the tool generates the embedding automatically using the provided configuration."
39-
),
40-
numCandidates: z
41-
.number()
42-
.int()
43-
.positive()
44-
.optional()
45-
.describe("Number of candidates for the ANN algorithm. Mandatory when exact is false."),
46-
limit: z.number().int().positive().optional().default(10),
47-
filter: zEJSON()
48-
.optional()
49-
.describe(
50-
"MQL filter that can only use filter fields from the index definition. Note to LLM: If unsure, use the `collection-indexes` tool to learn which fields can be used for filtering."
51-
),
52-
embeddingParameters: zSupportedEmbeddingParameters
53-
.optional()
54-
.describe(
55-
"The embedding model and its parameters to use to generate embeddings before searching. It is mandatory if queryVector is a string value. Note to LLM: If unsure, ask the user before providing one."
56-
),
57-
})
58-
.passthrough(),
59-
});
15+
import { AnyVectorSearchStage, VectorSearchStage } from "../mongodbSchemas.js";
16+
import {
17+
assertVectorSearchFilterFieldsAreIndexed,
18+
type VectorSearchIndex,
19+
} from "../../../helpers/assertVectorSearchFilterFieldsAreIndexed.js";
6020

6121
export const AggregateArgs = {
62-
pipeline: z.array(z.union([AnyStage, VectorSearchStage])).describe(
22+
pipeline: z.array(z.union([AnyVectorSearchStage, VectorSearchStage])).describe(
6323
`An array of aggregation stages to execute.
6424
\`$vectorSearch\` **MUST** be the first stage of the pipeline, or the first stage of a \`$unionWith\` subpipeline.
6525
### Usage Rules for \`$vectorSearch\`
@@ -97,6 +57,13 @@ export class AggregateTool extends MongoDBToolBase {
9757
try {
9858
const provider = await this.ensureConnected();
9959
await this.assertOnlyUsesPermittedStages(pipeline);
60+
if (await this.session.isSearchSupported()) {
61+
assertVectorSearchFilterFieldsAreIndexed({
62+
searchIndexes: (await provider.getSearchIndexes(database, collection)) as VectorSearchIndex[],
63+
pipeline,
64+
logger: this.session.logger,
65+
});
66+
}
10067

10168
// Check if aggregate operation uses an index if enabled
10269
if (this.config.indexCheck) {

0 commit comments

Comments
 (0)