Skip to content

Commit e266c58

Browse files
committed
chore: move helper into its own independent method
1 parent caf691a commit e266c58

File tree

5 files changed

+175
-158
lines changed

5 files changed

+175
-158
lines changed

src/common/search/embeddingsProvider.ts

Lines changed: 6 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@ import { embedMany } from "ai";
44
import type { UserConfig } from "../config.js";
55
import assert from "assert";
66
import { createFetch } from "@mongodb-js/devtools-proxy-support";
7-
import { z } from "zod";
7+
import {
8+
type EmbeddingParameters,
9+
type VoyageEmbeddingParameters,
10+
type VoyageModels,
11+
zVoyageAPIParameters,
12+
} from "../../tools/mongodb/mongodbSchemas.js";
813

914
type EmbeddingsInput = string;
1015
type Embeddings = number[] | unknown[];
11-
export type EmbeddingParameters = {
12-
inputType: "query" | "document";
13-
};
1416

1517
export interface EmbeddingsProvider<
1618
SupportedModels extends string,
@@ -23,40 +25,6 @@ export interface EmbeddingsProvider<
2325
): Promise<Embeddings[]>;
2426
}
2527

26-
export const zVoyageModels = z
27-
.enum(["voyage-3-large", "voyage-3.5", "voyage-3.5-lite", "voyage-code-3"])
28-
.default("voyage-3-large");
29-
30-
// Zod does not undestand JS boxed numbers (like Int32) as integer literals,
31-
// so we preprocess them to unwrap them so Zod understands them.
32-
function unboxNumber(v: unknown): number {
33-
if (v && typeof v === "object" && typeof v.valueOf === "function") {
34-
const n = Number(v.valueOf());
35-
if (!Number.isNaN(n)) return n;
36-
}
37-
return v as number;
38-
}
39-
40-
export const zVoyageEmbeddingParameters = z.object({
41-
outputDimension: z
42-
.preprocess(
43-
unboxNumber,
44-
z.union([z.literal(256), z.literal(512), z.literal(1024), z.literal(2048), z.literal(4096)])
45-
)
46-
.optional()
47-
.default(1024),
48-
outputDtype: z.enum(["float", "int8", "uint8", "binary", "ubinary"]).optional().default("float"),
49-
});
50-
51-
const zVoyageAPIParameters = zVoyageEmbeddingParameters
52-
.extend({
53-
inputType: z.enum(["query", "document"]),
54-
})
55-
.strip();
56-
57-
type VoyageModels = z.infer<typeof zVoyageModels>;
58-
type VoyageEmbeddingParameters = z.infer<typeof zVoyageEmbeddingParameters> & EmbeddingParameters;
59-
6028
class VoyageEmbeddingsProvider implements EmbeddingsProvider<VoyageModels, VoyageEmbeddingParameters> {
6129
private readonly voyage: VoyageProvider;
6230

@@ -105,6 +73,3 @@ export function getEmbeddingsProvider(
10573

10674
return undefined;
10775
}
108-
109-
export const zSupportedEmbeddingParameters = zVoyageEmbeddingParameters.extend({ model: zVoyageModels });
110-
export type SupportedEmbeddingParameters = z.infer<typeof zSupportedEmbeddingParameters>;

src/common/search/vectorSearchEmbeddingsManager.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@ import type { ConnectionManager } from "../connectionManager.js";
55
import z from "zod";
66
import { ErrorCodes, MongoDBError } from "../errors.js";
77
import { getEmbeddingsProvider } from "./embeddingsProvider.js";
8-
import type { EmbeddingParameters, SupportedEmbeddingParameters } from "./embeddingsProvider.js";
8+
import type { EmbeddingParameters } from "../../tools/mongodb/mongodbSchemas.js";
99
import { formatUntrustedData } from "../../tools/tool.js";
1010
import type { Similarity } from "../schemas.js";
11+
import type { SupportedEmbeddingParameters } from "../../tools/mongodb/mongodbSchemas.js";
1112

1213
export const quantizationEnum = z.enum(["none", "scalar", "binary"]);
1314
export type Quantization = z.infer<typeof quantizationEnum>;

src/helpers/collectFieldsFromVectorSearchFilter.ts

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,75 @@
11
// Based on -
2+
3+
import type z from "zod";
4+
import { ErrorCodes, MongoDBError } from "../common/errors.js";
5+
import type { VectorSearchStage } from "../tools/mongodb/mongodbSchemas.js";
6+
import { type CompositeLogger, LogId } from "../common/logger.js";
7+
28
// https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#mongodb-vector-search-pre-filter
39
const ALLOWED_LOGICAL_OPERATORS = ["$not", "$nor", "$and", "$or"];
410

11+
export type VectorSearchIndex = {
12+
name: string;
13+
latestDefinition: {
14+
fields: Array<
15+
| {
16+
type: "vector";
17+
}
18+
| {
19+
type: "filter";
20+
path: string;
21+
}
22+
>;
23+
};
24+
};
25+
26+
export function assertVectorSearchFilterFieldsAreIndexed({
27+
searchIndexes,
28+
pipeline,
29+
logger,
30+
}: {
31+
searchIndexes: VectorSearchIndex[];
32+
pipeline: Record<string, unknown>[];
33+
logger: CompositeLogger;
34+
}): void {
35+
const searchIndexesWithFilterFields = searchIndexes.reduce<Record<string, string[]>>(
36+
(indexFieldMap, searchIndex) => {
37+
const filterFields = searchIndex.latestDefinition.fields
38+
.map<string | undefined>((field) => {
39+
return field.type === "filter" ? field.path : undefined;
40+
})
41+
.filter((filterField) => filterField !== undefined);
42+
43+
indexFieldMap[searchIndex.name] = filterFields;
44+
return indexFieldMap;
45+
},
46+
{}
47+
);
48+
for (const stage of pipeline) {
49+
if ("$vectorSearch" in stage) {
50+
const { $vectorSearch: vectorSearchStage } = stage as z.infer<typeof VectorSearchStage>;
51+
const allowedFilterFields = searchIndexesWithFilterFields[vectorSearchStage.index];
52+
if (!allowedFilterFields) {
53+
logger.warning({
54+
id: LogId.toolValidationError,
55+
context: "aggregate tool",
56+
message: `Could not assert if filter fields are indexed - No filter fields found for index ${vectorSearchStage.index}`,
57+
});
58+
return;
59+
}
60+
61+
const filterFieldsInStage = collectFieldsFromVectorSearchFilter(vectorSearchStage.filter);
62+
const filterFieldsNotIndexed = filterFieldsInStage.filter((field) => !allowedFilterFields.includes(field));
63+
if (filterFieldsNotIndexed.length) {
64+
throw new MongoDBError(
65+
ErrorCodes.AtlasVectorSearchInvalidQuery,
66+
`Vector search stage contains filter on fields that are not indexed by index ${vectorSearchStage.index} - ${filterFieldsNotIndexed.join(", ")}`
67+
);
68+
}
69+
}
70+
}
71+
}
72+
573
export function collectFieldsFromVectorSearchFilter(filter: unknown): string[] {
674
if (!filter || typeof filter !== "object" || !Object.keys(filter).length) {
775
return [];
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import z from "zod";
2+
import { zEJSON } from "../args.js";
3+
4+
export const zVoyageModels = z
5+
.enum(["voyage-3-large", "voyage-3.5", "voyage-3.5-lite", "voyage-code-3"])
6+
.default("voyage-3-large");
7+
8+
// Zod does not undestand JS boxed numbers (like Int32) as integer literals,
9+
// so we preprocess them to unwrap them so Zod understands them.
10+
function unboxNumber(v: unknown): number {
11+
if (v && typeof v === "object" && typeof v.valueOf === "function") {
12+
const n = Number(v.valueOf());
13+
if (!Number.isNaN(n)) return n;
14+
}
15+
return v as number;
16+
}
17+
18+
export const zVoyageEmbeddingParameters = z.object({
19+
outputDimension: z
20+
.preprocess(
21+
unboxNumber,
22+
z.union([z.literal(256), z.literal(512), z.literal(1024), z.literal(2048), z.literal(4096)])
23+
)
24+
.optional()
25+
.default(1024),
26+
outputDtype: z.enum(["float", "int8", "uint8", "binary", "ubinary"]).optional().default("float"),
27+
});
28+
29+
export const zVoyageAPIParameters = zVoyageEmbeddingParameters
30+
.extend({
31+
inputType: z.enum(["query", "document"]),
32+
})
33+
.strip();
34+
35+
export type VoyageModels = z.infer<typeof zVoyageModels>;
36+
export type VoyageEmbeddingParameters = z.infer<typeof zVoyageEmbeddingParameters> & EmbeddingParameters;
37+
38+
export type EmbeddingParameters = {
39+
inputType: "query" | "document";
40+
};
41+
42+
export const zSupportedEmbeddingParameters = zVoyageEmbeddingParameters.extend({ model: zVoyageModels });
43+
export type SupportedEmbeddingParameters = z.infer<typeof zSupportedEmbeddingParameters>;
44+
45+
export const AnyVectorSearchStage = zEJSON();
46+
export const VectorSearchStage = z.object({
47+
$vectorSearch: z
48+
.object({
49+
exact: z
50+
.boolean()
51+
.optional()
52+
.default(false)
53+
.describe(
54+
"When true, uses an ENN algorithm, otherwise uses ANN. Using ENN is not compatible with numCandidates, in that case, numCandidates must be left empty."
55+
),
56+
index: z.string().describe("Name of the index, as retrieved from the `collection-indexes` tool."),
57+
path: z
58+
.string()
59+
.describe(
60+
"Field, in dot notation, where to search. There must be a vector search index for that field. Note to LLM: When unsure, use the 'collection-indexes' tool to validate that the field is indexed with a vector search index."
61+
),
62+
queryVector: z
63+
.union([z.string(), z.array(z.number())])
64+
.describe(
65+
"The content to search for. The embeddingParameters field is mandatory if the queryVector is a string, in that case, the tool generates the embedding automatically using the provided configuration."
66+
),
67+
numCandidates: z
68+
.number()
69+
.int()
70+
.positive()
71+
.optional()
72+
.describe("Number of candidates for the ANN algorithm. Mandatory when exact is false."),
73+
limit: z.number().int().positive().optional().default(10),
74+
filter: zEJSON()
75+
.optional()
76+
.describe(
77+
"MQL filter that can only use filter fields from the index definition. Note to LLM: If unsure, use the `collection-indexes` tool to learn which fields can be used for filtering."
78+
),
79+
embeddingParameters: zSupportedEmbeddingParameters
80+
.optional()
81+
.describe(
82+
"The embedding model and its parameters to use to generate embeddings before searching. It is mandatory if queryVector is a string value. Note to LLM: If unsure, ask the user before providing one."
83+
),
84+
})
85+
.passthrough(),
86+
});

src/tools/mongodb/read/aggregate.ts

Lines changed: 13 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -11,56 +11,15 @@ import { ErrorCodes, MongoDBError } from "../../../common/errors.js";
1111
import { collectCursorUntilMaxBytesLimit } from "../../../helpers/collectCursorUntilMaxBytes.js";
1212
import { operationWithFallback } from "../../../helpers/operationWithFallback.js";
1313
import { AGG_COUNT_MAX_TIME_MS_CAP, ONE_MB, CURSOR_LIMITS_TO_LLM_TEXT } from "../../../helpers/constants.js";
14-
import { zEJSON } from "../../args.js";
1514
import { LogId } from "../../../common/logger.js";
16-
import { zSupportedEmbeddingParameters } from "../../../common/search/embeddingsProvider.js";
17-
import { collectFieldsFromVectorSearchFilter } from "../../../helpers/collectFieldsFromVectorSearchFilter.js";
18-
19-
const AnyStage = zEJSON();
20-
const VectorSearchStage = z.object({
21-
$vectorSearch: z
22-
.object({
23-
exact: z
24-
.boolean()
25-
.optional()
26-
.default(false)
27-
.describe(
28-
"When true, uses an ENN algorithm, otherwise uses ANN. Using ENN is not compatible with numCandidates, in that case, numCandidates must be left empty."
29-
),
30-
index: z.string().describe("Name of the index, as retrieved from the `collection-indexes` tool."),
31-
path: z
32-
.string()
33-
.describe(
34-
"Field, in dot notation, where to search. There must be a vector search index for that field. Note to LLM: When unsure, use the 'collection-indexes' tool to validate that the field is indexed with a vector search index."
35-
),
36-
queryVector: z
37-
.union([z.string(), z.array(z.number())])
38-
.describe(
39-
"The content to search for. The embeddingParameters field is mandatory if the queryVector is a string, in that case, the tool generates the embedding automatically using the provided configuration."
40-
),
41-
numCandidates: z
42-
.number()
43-
.int()
44-
.positive()
45-
.optional()
46-
.describe("Number of candidates for the ANN algorithm. Mandatory when exact is false."),
47-
limit: z.number().int().positive().optional().default(10),
48-
filter: zEJSON()
49-
.optional()
50-
.describe(
51-
"MQL filter that can only use filter fields from the index definition. Note to LLM: If unsure, use the `collection-indexes` tool to learn which fields can be used for filtering."
52-
),
53-
embeddingParameters: zSupportedEmbeddingParameters
54-
.optional()
55-
.describe(
56-
"The embedding model and its parameters to use to generate embeddings before searching. It is mandatory if queryVector is a string value. Note to LLM: If unsure, ask the user before providing one."
57-
),
58-
})
59-
.passthrough(),
60-
});
15+
import { AnyVectorSearchStage, VectorSearchStage } from "../mongodbSchemas.js";
16+
import {
17+
assertVectorSearchFilterFieldsAreIndexed,
18+
type VectorSearchIndex,
19+
} from "../../../helpers/collectFieldsFromVectorSearchFilter.js";
6120

6221
export const AggregateArgs = {
63-
pipeline: z.array(z.union([AnyStage, VectorSearchStage])).describe(
22+
pipeline: z.array(z.union([AnyVectorSearchStage, VectorSearchStage])).describe(
6423
`An array of aggregation stages to execute.
6524
\`$vectorSearch\` **MUST** be the first stage of the pipeline, or the first stage of a \`$unionWith\` subpipeline.
6625
### Usage Rules for \`$vectorSearch\`
@@ -98,7 +57,13 @@ export class AggregateTool extends MongoDBToolBase {
9857
try {
9958
const provider = await this.ensureConnected();
10059
await this.assertOnlyUsesPermittedStages(pipeline);
101-
await this.assertVectorSearchFilterFieldsAreIndexed(database, collection, pipeline);
60+
if (await this.session.isSearchSupported()) {
61+
assertVectorSearchFilterFieldsAreIndexed({
62+
searchIndexes: (await provider.getSearchIndexes(database, collection)) as VectorSearchIndex[],
63+
pipeline,
64+
logger: this.session.logger,
65+
});
66+
}
10267

10368
// Check if aggregate operation uses an index if enabled
10469
if (this.config.indexCheck) {
@@ -220,74 +185,6 @@ export class AggregateTool extends MongoDBToolBase {
220185
}
221186
}
222187

223-
private async assertVectorSearchFilterFieldsAreIndexed(
224-
database: string,
225-
collection: string,
226-
pipeline: Record<string, unknown>[]
227-
): Promise<void> {
228-
if (!(await this.session.isSearchSupported())) {
229-
return;
230-
}
231-
232-
const searchIndexesWithFilterFields = await this.searchIndexesWithFilterFields(database, collection);
233-
for (const stage of pipeline) {
234-
if ("$vectorSearch" in stage) {
235-
const { $vectorSearch: vectorSearchStage } = stage as z.infer<typeof VectorSearchStage>;
236-
const allowedFilterFields = searchIndexesWithFilterFields[vectorSearchStage.index];
237-
if (!allowedFilterFields) {
238-
this.session.logger.warning({
239-
id: LogId.toolValidationError,
240-
context: "aggregate tool",
241-
message: `Could not assert if filter fields are indexed - No filter fields found for index ${vectorSearchStage.index}`,
242-
});
243-
return;
244-
}
245-
246-
const filterFieldsInStage = collectFieldsFromVectorSearchFilter(vectorSearchStage.filter);
247-
const filterFieldsNotIndexed = filterFieldsInStage.filter(
248-
(field) => !allowedFilterFields.includes(field)
249-
);
250-
if (filterFieldsNotIndexed.length) {
251-
throw new MongoDBError(
252-
ErrorCodes.AtlasVectorSearchInvalidQuery,
253-
`Vector search stage contains filter on fields that are not indexed by index ${vectorSearchStage.index} - ${filterFieldsNotIndexed.join(", ")}`
254-
);
255-
}
256-
}
257-
}
258-
}
259-
260-
private async searchIndexesWithFilterFields(
261-
database: string,
262-
collection: string
263-
): Promise<Record<string, string[]>> {
264-
const searchIndexes = (await this.session.serviceProvider.getSearchIndexes(database, collection)) as Array<{
265-
name: string;
266-
latestDefinition: {
267-
fields: Array<
268-
| {
269-
type: "vector";
270-
}
271-
| {
272-
type: "filter";
273-
path: string;
274-
}
275-
>;
276-
};
277-
}>;
278-
279-
return searchIndexes.reduce<Record<string, string[]>>((indexFieldMap, searchIndex) => {
280-
const filterFields = searchIndex.latestDefinition.fields
281-
.map<string | undefined>((field) => {
282-
return field.type === "filter" ? field.path : undefined;
283-
})
284-
.filter((filterField) => filterField !== undefined);
285-
286-
indexFieldMap[searchIndex.name] = filterFields;
287-
return indexFieldMap;
288-
}, {});
289-
}
290-
291188
private async countAggregationResultDocuments({
292189
provider,
293190
database,

0 commit comments

Comments
 (0)