Skip to content

Commit

Permalink
feat: implement filters for MongoDBAtlasVectorSearch (run-llama#1142)
Browse files Browse the repository at this point in the history
  • Loading branch information
thucpn authored Sep 5, 2024
1 parent e8f229c commit 11b3856
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 26 deletions.
5 changes: 5 additions & 0 deletions .changeset/red-vans-taste.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"llamaindex": patch
---

implement filters for MongoDBAtlasVectorSearch
13 changes: 11 additions & 2 deletions examples/mongodb/2_load_and_index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,23 @@ async function loadAndIndex() {
"full_text",
]);

const FILTER_METADATA_FIELD = "content_type";

documents.forEach((document, index) => {
const contentType = ["tweet", "post", "story"][index % 3]; // assign a random content type to each document
document.metadata = {
...document.metadata,
[FILTER_METADATA_FIELD]: contentType,
};
});

// create Atlas as a vector store
const vectorStore = new MongoDBAtlasVectorSearch({
mongodbClient: client,
dbName: databaseName,
collectionName: vectorCollectionName, // this is where your embeddings will be stored
indexName: indexName, // this is the name of the index you will need to create
indexedMetadataFields: [FILTER_METADATA_FIELD], // this is the field that will be used for the query
});

// now create an index from all the Documents and store them in Atlas
Expand All @@ -46,5 +57,3 @@ async function loadAndIndex() {
}

loadAndIndex().catch(console.error);

// you can't query your index yet because you need to create a vector search index in mongodb's UI now
14 changes: 13 additions & 1 deletion examples/mongodb/3_query.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,24 @@ async function query() {
dbName: process.env.MONGODB_DATABASE!,
collectionName: process.env.MONGODB_VECTORS!,
indexName: process.env.MONGODB_VECTOR_INDEX!,
indexedMetadataFields: ["content_type"],
});

const index = await VectorStoreIndex.fromVectorStore(store);

const retriever = index.asRetriever({ similarityTopK: 20 });
const queryEngine = index.asQueryEngine({ retriever });
const queryEngine = index.asQueryEngine({
retriever,
preFilters: {
filters: [
{
key: "content_type",
value: "story", // try "tweet" or "post" to see the difference
operator: "==",
},
],
},
});
const result = await queryEngine.query({
query: "What does author receive when he was 11 years old?", // Isaac Asimov's "Foundation" for Christmas
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,62 @@ import { getEnv } from "@llamaindex/env";
import type { BulkWriteOptions, Collection } from "mongodb";
import { MongoClient } from "mongodb";
import {
FilterCondition,
VectorStoreBase,
type FilterOperator,
type MetadataFilter,
type MetadataFilters,
type VectorStoreNoEmbedModel,
type VectorStoreQuery,
type VectorStoreQueryResult,
} from "./types.js";
import { metadataDictToNode, nodeToMetadata } from "./utils.js";

// Utility function to convert metadata filters to MongoDB filter
function toMongoDBFilter(
standardFilters: MetadataFilters,
): Record<string, any> {
const filters: Record<string, any> = {};
for (const filter of standardFilters?.filters ?? []) {
filters[filter.key] = filter.value;
// define your Atlas Search index. See detail https://www.mongodb.com/docs/atlas/atlas-search/field-types/knn-vector/
const DEFAULT_EMBEDDING_DEFINITION = {
type: "knnVector",
dimensions: 1536,
similarity: "cosine",
};

function mapLcMqlFilterOperators(operator: string): string {
const operatorMap: { [key in FilterOperator]?: string } = {
"==": "$eq",
"<": "$lt",
"<=": "$lte",
">": "$gt",
">=": "$gte",
"!=": "$ne",
in: "$in",
nin: "$nin",
};
const mqlOperator = operatorMap[operator as FilterOperator];
if (!mqlOperator) throw new Error(`Unsupported operator: ${operator}`);
return mqlOperator;
}

function toMongoDBFilter(filters?: MetadataFilters): Record<string, any> {
if (!filters) return {};

const createFilterObject = (mf: MetadataFilter) => ({
[mf.key]: {
[mapLcMqlFilterOperators(mf.operator)]: mf.value,
},
});

if (filters.filters.length === 1) {
return createFilterObject(filters.filters[0]);
}

if (filters.condition === FilterCondition.AND) {
return { $and: filters.filters.map(createFilterObject) };
}

if (filters.condition === FilterCondition.OR) {
return { $or: filters.filters.map(createFilterObject) };
}
return filters;

throw new Error("filters condition not recognized. Must be AND or OR");
}

/**
Expand All @@ -38,6 +77,8 @@ export class MongoDBAtlasVectorSearch
dbName: string;
collectionName: string;
autoCreateIndex: boolean;
embeddingDefinition: Record<string, unknown>;
indexedMetadataFields: string[];

/**
* The used MongoClient. If not given, a new MongoClient is created based on the MONGODB_URI env variable.
Expand Down Expand Up @@ -98,26 +139,14 @@ export class MongoDBAtlasVectorSearch
numCandidates: (query: VectorStoreQuery) => number;
private collection?: Collection;

// define your Atlas Search index. See detail https://www.mongodb.com/docs/atlas/atlas-search/field-types/knn-vector/
readonly SEARCH_INDEX_DEFINITION = {
mappings: {
dynamic: true,
fields: {
embedding: {
type: "knnVector",
dimensions: 1536,
similarity: "cosine",
},
},
},
};

constructor(
init: Partial<MongoDBAtlasVectorSearch> & {
dbName: string;
collectionName: string;
embedModel?: BaseEmbedding;
autoCreateIndex?: boolean;
indexedMetadataFields?: string[];
embeddingDefinition?: Record<string, unknown>;
},
) {
super(init.embedModel);
Expand All @@ -136,6 +165,11 @@ export class MongoDBAtlasVectorSearch
this.dbName = init.dbName ?? "default_db";
this.collectionName = init.collectionName ?? "default_collection";
this.autoCreateIndex = init.autoCreateIndex ?? true;
this.indexedMetadataFields = init.indexedMetadataFields ?? [];
this.embeddingDefinition = {
...DEFAULT_EMBEDDING_DEFINITION,
...(init.embeddingDefinition ?? {}),
};
this.indexName = init.indexName ?? "default";
this.embeddingKey = init.embeddingKey ?? "embedding";
this.idKey = init.idKey ?? "id";
Expand All @@ -161,9 +195,21 @@ export class MongoDBAtlasVectorSearch
(index) => index.name === this.indexName,
);
if (!indexExists) {
const additionalDefinition: Record<string, { type: string }> = {};
this.indexedMetadataFields.forEach((field) => {
additionalDefinition[field] = { type: "token" };
});
await this.collection.createSearchIndex({
name: this.indexName,
definition: this.SEARCH_INDEX_DEFINITION,
definition: {
mappings: {
dynamic: true,
fields: {
embedding: this.embeddingDefinition,
...additionalDefinition,
},
},
},
});
}
}
Expand All @@ -189,11 +235,18 @@ export class MongoDBAtlasVectorSearch
this.flatMetadata,
);

// Include the specified metadata fields in the top level of the document (to help filter)
const populatedMetadata: Record<string, unknown> = {};
for (const field of this.indexedMetadataFields) {
populatedMetadata[field] = metadata[field];
}

return {
[this.idKey]: node.id_,
[this.embeddingKey]: node.getEmbedding(),
[this.textKey]: node.getContent(MetadataMode.NONE) || "",
[this.metadataKey]: metadata,
...populatedMetadata,
};
});

Expand Down

0 comments on commit 11b3856

Please sign in to comment.