Skip to content

Commit 080d840

Browse files
chore: add validation for vector stage pre-filter
1 parent 64ac05e commit 080d840

File tree

5 files changed

+633
-247
lines changed

5 files changed

+633
-247
lines changed

src/common/logger.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ export const LogId = {
4949
toolUpdateFailure: mongoLogId(1_005_001),
5050
resourceUpdateFailure: mongoLogId(1_005_002),
5151
updateToolMetadata: mongoLogId(1_005_003),
52+
toolValidationError: mongoLogId(1_005_004),
5253

5354
streamableHttpTransportStarted: mongoLogId(1_006_001),
5455
streamableHttpTransportSessionCloseFailure: mongoLogId(1_006_002),
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Based on -
2+
// https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#mongodb-vector-search-pre-filter
3+
const ALLOWED_LOGICAL_OPERATORS = ["$not", "$nor", "$and", "$or"];
4+
5+
export function collectFieldsFromVectorSearchFilter(filter: unknown): string[] {
6+
if (!filter || typeof filter !== "object" || !Object.keys(filter).length) {
7+
return [];
8+
}
9+
10+
const collectedFields = Object.entries(filter).reduce<string[]>((collectedFields, [maybeField, fieldMQL]) => {
11+
if (ALLOWED_LOGICAL_OPERATORS.includes(maybeField) && Array.isArray(fieldMQL)) {
12+
return fieldMQL.flatMap((mql) => collectFieldsFromVectorSearchFilter(mql));
13+
}
14+
15+
if (!ALLOWED_LOGICAL_OPERATORS.includes(maybeField)) {
16+
collectedFields.push(maybeField);
17+
}
18+
return collectedFields;
19+
}, []);
20+
21+
return Array.from(new Set(collectedFields));
22+
}

src/tools/mongodb/read/aggregate.ts

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import { AGG_COUNT_MAX_TIME_MS_CAP, ONE_MB, CURSOR_LIMITS_TO_LLM_TEXT } from "..
1414
import { zEJSON } from "../../args.js";
1515
import { LogId } from "../../../common/logger.js";
1616
import { zSupportedEmbeddingParameters } from "../../../common/search/embeddingsProvider.js";
17+
import { collectFieldsFromVectorSearchFilter } from "../../../helpers/collectFieldsFromVectorSearchFilter.js";
1718

1819
const AnyStage = zEJSON();
1920
const VectorSearchStage = z.object({
@@ -97,6 +98,7 @@ export class AggregateTool extends MongoDBToolBase {
9798
try {
9899
const provider = await this.ensureConnected();
99100
await this.assertOnlyUsesPermittedStages(pipeline);
101+
await this.assertVectorSearchFilterFieldsAreIndexed(database, collection, pipeline);
100102

101103
// Check if aggregate operation uses an index if enabled
102104
if (this.config.indexCheck) {
@@ -202,6 +204,74 @@ export class AggregateTool extends MongoDBToolBase {
202204
}
203205
}
204206

207+
private async assertVectorSearchFilterFieldsAreIndexed(
208+
database: string,
209+
collection: string,
210+
pipeline: Record<string, unknown>[]
211+
): Promise<void> {
212+
if (!(await this.session.isSearchSupported())) {
213+
return;
214+
}
215+
216+
const searchIndexesWithFilterFields = await this.searchIndexesWithFilterFields(database, collection);
217+
for (const stage of pipeline) {
218+
if ("$vectorSearch" in stage) {
219+
const { $vectorSearch: vectorSearchStage } = stage as z.infer<typeof VectorSearchStage>;
220+
const allowedFilterFields = searchIndexesWithFilterFields[vectorSearchStage.index];
221+
if (!allowedFilterFields) {
222+
this.session.logger.warning({
223+
id: LogId.toolValidationError,
224+
context: "aggregate tool",
225+
message: `Could not assert if filter fields are indexed - No filter fields found for index ${vectorSearchStage.index}`,
226+
});
227+
return;
228+
}
229+
230+
const filterFieldsInStage = collectFieldsFromVectorSearchFilter(vectorSearchStage.filter);
231+
const filterFieldsNotIndexed = filterFieldsInStage.filter(
232+
(field) => !allowedFilterFields.includes(field)
233+
);
234+
if (filterFieldsNotIndexed.length) {
235+
throw new MongoDBError(
236+
ErrorCodes.AtlasVectorSearchInvalidQuery,
237+
`Vector search stage contains filter on fields are not indexed by index ${vectorSearchStage.index} - ${filterFieldsNotIndexed.join(", ")}`
238+
);
239+
}
240+
}
241+
}
242+
}
243+
244+
private async searchIndexesWithFilterFields(
245+
database: string,
246+
collection: string
247+
): Promise<Record<string, string[]>> {
248+
const searchIndexes = (await this.session.serviceProvider.getSearchIndexes(database, collection)) as Array<{
249+
name: string;
250+
latestDefinition: {
251+
fields: Array<
252+
| {
253+
type: "vector";
254+
}
255+
| {
256+
type: "filter";
257+
path: string;
258+
}
259+
>;
260+
};
261+
}>;
262+
263+
return searchIndexes.reduce<Record<string, string[]>>((indexFieldMap, searchIndex) => {
264+
const filterFields = searchIndex.latestDefinition.fields
265+
.map<string | undefined>((field) => {
266+
return field.type === "filter" ? field.path : undefined;
267+
})
268+
.filter((filterField) => filterField !== undefined);
269+
270+
indexFieldMap[searchIndex.name] = filterFields;
271+
return indexFieldMap;
272+
}, {});
273+
}
274+
205275
private async countAggregationResultDocuments({
206276
provider,
207277
database,

0 commit comments

Comments
 (0)