Skip to content

Commit 7b597cf

Browse files
authored
feat: support metadata post filter (#122)
* feat: support metadata post filter * feat: retrieve API supports passing filter parameters * refine
1 parent 0a4db02 commit 7b597cf

File tree

16 files changed

+461
-218
lines changed

16 files changed

+461
-218
lines changed

ddl/0-initial-schema.sql

+1
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ CREATE TABLE retrieve_result
222222
document_id INT NOT NULL,
223223
document_chunk_node_id BINARY(16) NOT NULL,
224224
document_node_id BINARY(16) NOT NULL,
225+
document_metadata JSON NOT NULL,
225226
chunk_text TEXT NOT NULL,
226227
chunk_metadata JSON NOT NULL,
227228
PRIMARY KEY (id),

extensions/html-loader/HtmlLoader.ts

+4-6
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@ import type {Element, Root} from 'hast';
66
import {select, selectAll} from 'hast-util-select';
77
import {toText} from 'hast-util-to-text';
88
import {match} from 'path-to-regexp';
9-
import rehypeParse from 'rehype-parse';
109
import {Processor, unified} from 'unified';
1110
import {remove} from 'unist-util-remove';
11+
import rehypeParse from 'rehype-parse';
1212
import htmlLoaderMeta, {
1313
DEFAULT_EXCLUDE_SELECTORS,
1414
DEFAULT_TEXT_SELECTORS,
1515
ExtractedMetadata,
16-
HtmlLoaderOptions,
16+
type HtmlLoaderOptions,
1717
MetadataExtractor,
1818
MetadataExtractorType,
1919
URLMetadataExtractor
@@ -40,10 +40,8 @@ export default class HtmlLoader extends rag.Loader<HtmlLoaderOptions, {}> {
4040
content: matchedTexts,
4141
hash: this.getTextHash(matchedTexts),
4242
metadata: {
43-
documentUrl: url,
44-
documentMetadata: {
45-
...metadataFromURL
46-
}
43+
url: url,
44+
...metadataFromURL
4745
},
4846
} satisfies rag.Content<{}>;
4947
}

src/app/api/v1/chats/route.ts

+19-26
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1-
import { type Chat, createChat, getChatByUrlKey, listChats } from '@/core/repositories/chat';
2-
import { ChatEngineOptions, getChatEngine, getDefaultChatEngine } from '@/core/repositories/chat_engine';
3-
import { getIndexByNameOrThrow } from '@/core/repositories/index_';
4-
import { LlamaindexChatService } from '@/core/services/llamaindex/chating';
5-
import { toPageRequest } from '@/lib/database';
6-
import { CHAT_CAN_NOT_ASSIGN_SESSION_ID_ERROR, CHAT_ENGINE_NOT_FOUND_ERROR } from '@/lib/errors';
7-
import { defineHandler } from '@/lib/next/handler';
8-
import { baseRegistry } from '@/rag-spec/base';
9-
import { getFlow } from '@/rag-spec/createFlow';
10-
import { notFound } from 'next/navigation';
11-
import { NextResponse } from 'next/server';
12-
import { z } from 'zod';
1+
import {type Chat, createChat, getChatByUrlKey, listChats} from '@/core/repositories/chat';
2+
import {getChatEngineConfig} from '@/core/repositories/chat_engine';
3+
import {getIndexByNameOrThrow} from '@/core/repositories/index_';
4+
import {LlamaindexChatService} from '@/core/services/llamaindex/chating';
5+
import {toPageRequest} from '@/lib/database';
6+
import {CHAT_CAN_NOT_ASSIGN_SESSION_ID_ERROR} from '@/lib/errors';
7+
import {defineHandler} from '@/lib/next/handler';
8+
import {baseRegistry} from '@/rag-spec/base';
9+
import {getFlow} from '@/rag-spec/createFlow';
10+
import {notFound} from 'next/navigation';
11+
import {NextResponse} from 'next/server';
12+
import {z} from 'zod';
1313

1414
const ChatRequest = z.object({
1515
messages: z.object({
@@ -50,12 +50,18 @@ export const POST = defineHandler({
5050
return CHAT_CAN_NOT_ASSIGN_SESSION_ID_ERROR;
5151
}
5252

53+
// TODO: using AI generated title.
54+
let title = body.name ?? DEFAULT_CHAT_TITLE;
55+
if (title.length > 255) {
56+
title = title.substring(0, 255);
57+
}
58+
5359
return await createChat({
5460
engine,
5561
engine_options: JSON.stringify(engineOptions),
5662
created_at: new Date(),
5763
created_by: userId,
58-
title: body.name ?? DEFAULT_CHAT_TITLE,
64+
title: title,
5965
});
6066
}
6167

@@ -96,19 +102,6 @@ export const POST = defineHandler({
96102
return chatStream.toResponse();
97103
});
98104

99-
async function getChatEngineConfig (engineConfigId?: number): Promise<[string, ChatEngineOptions]> {
100-
if (engineConfigId) {
101-
const chatEngine = await getChatEngine(engineConfigId);
102-
if (!chatEngine) {
103-
throw CHAT_ENGINE_NOT_FOUND_ERROR.format(engineConfigId);
104-
}
105-
return [chatEngine.engine, chatEngine.engine_options];
106-
} else {
107-
const config = await getDefaultChatEngine();
108-
return [config.engine, config.engine_options];
109-
}
110-
}
111-
112105
export const GET = defineHandler({
113106
auth: 'anonymous',
114107
searchParams: z.object({

src/app/api/v1/indexes/[name]/retrieve/route.ts

+26-13
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
1-
import { LlamaindexRetrieveService } from '@/core/services/llamaindex/retrieving';
2-
import { retrieveOptionsSchema } from '@/core/services/retrieving';
1+
import {getChatEngineConfig} from "@/core/repositories/chat_engine";
32
import {getIndexByName} from '@/core/repositories/index_';
4-
import { defineHandler } from '@/lib/next/handler';
5-
import { baseRegistry } from '@/rag-spec/base';
6-
import { getFlow } from '@/rag-spec/createFlow';
7-
import { notFound } from 'next/navigation';
3+
import {LlamaindexRetrieveService} from '@/core/services/llamaindex/retrieving';
4+
import {retrieveOptionsSchema} from '@/core/services/retrieving';
5+
import {getEmbedding} from "@/lib/llamaindex/converters/embedding";
6+
import {getLLM} from "@/lib/llamaindex/converters/llm";
7+
import {defineHandler} from '@/lib/next/handler';
8+
import {baseRegistry} from '@/rag-spec/base';
9+
import {getFlow} from '@/rag-spec/createFlow';
10+
import {serviceContextFromDefaults} from "llamaindex";
11+
import {notFound} from 'next/navigation';
812
import z from 'zod';
913

1014
export const POST = defineHandler({
1115
params: z.object({
12-
name: z.string()
16+
name: z.string(),
1317
}),
1418
body: retrieveOptionsSchema,
1519
}, async ({
@@ -21,17 +25,26 @@ export const POST = defineHandler({
2125
notFound();
2226
}
2327

28+
const [engine, engineOptions] = await getChatEngineConfig(body.engine);
29+
const {
30+
llm: {
31+
provider: llmProvider = 'openai',
32+
config: llmConfig = {}
33+
} = {},
34+
} = engineOptions;
35+
2436
const flow = await getFlow(baseRegistry);
37+
const serviceContext = serviceContextFromDefaults({
38+
llm: getLLM(flow, llmProvider, llmConfig),
39+
embedModel: getEmbedding(flow, index.config.embedding.provider, index.config.embedding.config),
40+
});
2541

2642
const retrieveService = new LlamaindexRetrieveService({
27-
// TODO: support llm reranker
28-
reranker: {
29-
provider: 'cohere',
30-
options: {},
31-
},
43+
metadata_filter: engineOptions.metadata_filter,
44+
reranker: engineOptions.reranker,
3245
flow,
3346
index,
34-
serviceContext: {} as any
47+
serviceContext
3548
});
3649

3750
const result = await retrieveService.retrieve(body);

src/app/api/v1/tasks/temp_fill_document_metadata/route.ts

+31-39
Original file line numberDiff line numberDiff line change
@@ -12,66 +12,58 @@ export const GET = defineHandler({ auth: 'cronjob' }, async () => {
1212
const reader = fromFlowReaders(flow, index.config.reader);
1313

1414
await executeInSafeDuration(async () => {
15-
const documentWithoutMeta = await getDb().selectFrom('llamaindex_document_node')
15+
const documentWithoutMeta = await getDb()
16+
.selectFrom('llamaindex_document_node')
1617
.selectAll()
1718
.where((eb) => {
18-
return eb(
19-
eb.ref('metadata', '->$').key('documentMetadata' as never),
20-
'is',
21-
eb.val(null),
22-
)
19+
return eb.and([
20+
eb(eb.fn('JSON_UNQUOTE', [eb.ref('metadata')]), '=', eb.val('null')),
21+
])
2322
})
2423
.limit(100)
2524
.execute();
2625

2726
console.log(`Found ${documentWithoutMeta.length} documents without metadata.`)
2827

2928
const documentIds = Array.from(new Set(documentWithoutMeta.map(doc => doc.document_id)));
29+
if (documentIds.length == 0) {
30+
return false;
31+
}
32+
3033
const documents = await getDb()
3134
.selectFrom('document')
32-
.select('id')
33-
.select('content_uri')
34-
.select('source_uri')
35+
.select([
36+
'id', 'source_uri', 'content_uri'
37+
])
3538
.where('id', 'in', documentIds)
3639
.where('mime', '=', 'text/html')
3740
.execute();
3841

39-
console.log(`Found ${documents.length} documents to process.`)
40-
41-
if (documents.length == 0) {
42-
return false;
43-
}
42+
console.log(`Found ${documents.length} documents to process.`);
4443

4544
await Promise.all(documents.map(async document => {
46-
const docsWithMeta = await reader.loadData({
45+
return reader.loadData({
4746
mime: 'text/html',
4847
content_uri: document.content_uri,
4948
source_uri: document.source_uri,
49+
}).then((docsWithMeta) => {
50+
console.log(`Processing document ${document.id}.`)
51+
for (let docWithMeta of docsWithMeta) {
52+
getDb()
53+
.updateTable('llamaindex_document_node')
54+
.where('document_id', '=', document.id)
55+
.set(({ eb }) => ({
56+
metadata: eb.fn('JSON_MERGE_PATCH', [
57+
eb.ref('metadata'),
58+
eb.val(JSON.stringify(docWithMeta.metadata)),
59+
]),
60+
}))
61+
.execute()
62+
.then(null)
63+
}
64+
}).catch((e) => {
65+
console.error(`Failed to process document ${document.id}.`, e);
5066
});
51-
52-
console.log(`Processing document ${document.id}.`)
53-
54-
for (let docWithMeta of docsWithMeta) {
55-
await getDb().updateTable('llamaindex_document_node')
56-
.where('document_id', '=', document.id)
57-
.set(({ eb }) => ({
58-
metadata: eb.fn('JSON_MERGE_PATCH', [
59-
eb.ref('metadata'),
60-
eb.val(JSON.stringify(docWithMeta.metadata)),
61-
]),
62-
}))
63-
.execute();
64-
65-
await getDb().updateTable('llamaindex_document_chunk_node_default')
66-
.where('document_id', '=', document.id)
67-
.set(({ eb }) => ({
68-
metadata: eb.fn('JSON_MERGE_PATCH', [
69-
eb.ref('metadata'),
70-
eb.val(JSON.stringify(docWithMeta.metadata)),
71-
]),
72-
}))
73-
.execute();
74-
}
7567
}));
7668

7769
return true;

src/components/semantic-search.tsx

+4-2
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,16 @@ function InternalSearchBox () {
6060

6161
const disabled = loading || transitioning;
6262

63-
const search = (text: string) => {
63+
const search = (query: string) => {
6464
setLoading(true);
6565
startTransition(() => {
66+
// TODO: Allow using different indexes.
6667
fetch('/api/v1/indexes/default/retrieve', {
6768
method: 'post',
6869
body: JSON.stringify({
69-
text,
70+
query,
7071
top_k: 5,
72+
// TODO: Support metadata filters.
7173
}),
7274
}).then(handleErrors)
7375
.then(res => res.json())

src/core/db/schema.d.ts

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import type { ColumnType } from "kysely";
1+
import type {ColumnType} from "kysely";
22

33
export type Generated<T> = T extends ColumnType<infer S, infer I, infer U>
44
? ColumnType<S, I | undefined, U>
@@ -165,6 +165,7 @@ export interface RetrieveResult {
165165
document_chunk_node_id: Buffer;
166166
document_id: number;
167167
document_node_id: Buffer;
168+
document_metadata: Json;
168169
id: Generated<number>;
169170
relevance_score: number;
170171
retrieve_id: number;

src/core/repositories/chat_engine.ts

+22-8
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
import type { RetrieveOptions } from '@/core/services/retrieving';
2-
import { type DBv1, getDb, tx } from '@/core/db';
3-
import { executePage, type PageRequest } from '@/lib/database';
4-
import { APIError } from '@/lib/errors';
5-
import type { Rewrite } from '@/lib/type-utils';
6-
import type { Insertable, Selectable, Updateable } from 'kysely';
7-
import { notFound } from 'next/navigation';
8-
import { z } from 'zod';
1+
import {type DBv1, getDb, tx} from '@/core/db';
2+
import type {RetrieveOptions} from '@/core/services/retrieving';
3+
import {executePage, type PageRequest} from '@/lib/database';
4+
import {APIError, CHAT_ENGINE_NOT_FOUND_ERROR} from '@/lib/errors';
5+
import type {Rewrite} from '@/lib/type-utils';
6+
import type {Insertable, Selectable, Updateable} from 'kysely';
7+
import {notFound} from 'next/navigation';
8+
import {z} from 'zod';
99

1010
export type ChatEngine = Rewrite<Selectable<DBv1['chat_engine']>, { engine_options: ChatEngineOptions }>;
1111
export type CreateChatEngine = Rewrite<Insertable<DBv1['chat_engine']>, { engine_options: ChatEngineOptions }>;
@@ -16,6 +16,7 @@ export type ChatEngineOptions = CondenseQuestionChatEngineOptions;
1616
export interface CondenseQuestionChatEngineOptions {
1717
index_id?: number;
1818
retriever?: Pick<RetrieveOptions, 'search_top_k' | 'top_k'>;
19+
metadata_filter?: { provider: string, config?: any };
1920
reranker?: { provider: string, config?: any };
2021
prompts?: {
2122
textQa?: string
@@ -67,6 +68,19 @@ export async function getDefaultChatEngine () {
6768
.executeTakeFirstOrThrow();
6869
}
6970

71+
export async function getChatEngineConfig (engineConfigId?: number): Promise<[string, ChatEngineOptions]> {
72+
if (engineConfigId) {
73+
const chatEngine = await getChatEngine(engineConfigId);
74+
if (!chatEngine) {
75+
throw CHAT_ENGINE_NOT_FOUND_ERROR.format(engineConfigId);
76+
}
77+
return [chatEngine.engine, chatEngine.engine_options];
78+
} else {
79+
const config = await getDefaultChatEngine();
80+
return [config.engine, config.engine_options];
81+
}
82+
}
83+
7084
export async function listChatEngine (request: PageRequest) {
7185
return await executePage(getDb()
7286
.selectFrom('chat_engine')

src/core/repositories/retrieve.ts

+13-11
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
1-
import type { RetrieveOptions } from '@/core/services/retrieving';
2-
import { DBv1, getDb, tx } from '@/core/db';
3-
import { executePage, type PageRequest } from '@/lib/database';
4-
import { uuidToBin } from '@/lib/kysely';
5-
import type { Overwrite } from '@tanstack/table-core';
6-
import type { Insertable, Selectable, Updateable } from 'kysely';
7-
import type { UUID } from 'node:crypto';
1+
import {DBv1, getDb, tx} from '@/core/db';
2+
import {Json} from "@/core/db/schema";
3+
import type {RetrieveOptions} from '@/core/services/retrieving';
4+
import {executePage, type PageRequest} from '@/lib/database';
5+
import {uuidToBin} from '@/lib/kysely';
6+
import type {Overwrite} from '@tanstack/table-core';
7+
import type {Insertable, Selectable, Updateable} from 'kysely';
8+
import type {UUID} from 'node:crypto';
89

910
export type Retrieve = Overwrite<Selectable<DBv1['retrieve']>, { options: RetrieveOptions }>
1011
export type CreateRetrieve = Overwrite<Insertable<DBv1['retrieve']>, { options: RetrieveOptions }>
1112
export type UpdateRetrieve = Overwrite<Updateable<DBv1['retrieve']>, { options?: RetrieveOptions }>
12-
export type RetrieveResult = Overwrite<Selectable<DBv1['retrieve_result']>, { document_chunk_node_id: UUID, document_node_id: UUID }>
13-
export type CreateRetrieveResult = Overwrite<Insertable<DBv1['retrieve_result']>, { document_chunk_node_id: UUID, document_node_id: UUID }>
14-
export type UpdateRetrieveResult = Overwrite<Updateable<DBv1['retrieve_result']>, { document_chunk_node_id: UUID, document_node_id: UUID }>
13+
export type RetrieveResult = Overwrite<Selectable<DBv1['retrieve_result']>, { document_chunk_node_id: UUID, document_node_id: UUID, document_metadata: Json }>
14+
export type CreateRetrieveResult = Overwrite<Insertable<DBv1['retrieve_result']>, { document_chunk_node_id: UUID, document_node_id: UUID, document_metadata: Json }>
15+
export type UpdateRetrieveResult = Overwrite<Updateable<DBv1['retrieve_result']>, { document_chunk_node_id: UUID, document_node_id: UUID, document_metadata: Json }>
1516

1617
export async function getRetrieve (id: number) {
1718
return await getDb()
@@ -81,9 +82,10 @@ export async function finishRetrieve (id: number, reranked: boolean, results: Cr
8182
if (results.length > 0) {
8283
await getDb()
8384
.insertInto('retrieve_result')
84-
.values(results.map(({ document_node_id, document_chunk_node_id, ...rest }) => ({
85+
.values(results.map(({ document_node_id, document_chunk_node_id, document_metadata, ...rest }) => ({
8586
document_node_id: uuidToBin(document_node_id),
8687
document_chunk_node_id: uuidToBin(document_chunk_node_id),
88+
document_metadata: JSON.stringify(document_metadata),
8789
...rest,
8890
})))
8991
.execute();

0 commit comments

Comments
 (0)