From 24a39aefb86b6715b4db6c7a8375723dc15f85e7 Mon Sep 17 00:00:00 2001 From: Marcus Schiesser Date: Thu, 9 May 2024 15:16:34 +0800 Subject: [PATCH] feat: send retrieve start and end events (#827) --- .../core/src/callbacks/CallbackManager.ts | 4 ++ .../core/src/indices/vectorStore/index.ts | 38 +++++++++---------- packages/core/src/llm/types.ts | 8 ++++ 3 files changed, 31 insertions(+), 19 deletions(-) diff --git a/packages/core/src/callbacks/CallbackManager.ts b/packages/core/src/callbacks/CallbackManager.ts index b7a8eb33d3..3f8eae0d20 100644 --- a/packages/core/src/callbacks/CallbackManager.ts +++ b/packages/core/src/callbacks/CallbackManager.ts @@ -12,6 +12,8 @@ import type { LLMStreamEvent, LLMToolCallEvent, LLMToolResultEvent, + RetrievalEndEvent, + RetrievalStartEvent, } from "../llm/types.js"; export class LlamaIndexCustomEvent extends CustomEvent { @@ -45,6 +47,8 @@ export interface LlamaIndexEventMaps { * @deprecated */ retrieve: CustomEvent; + "retrieve-start": RetrievalStartEvent; + "retrieve-end": RetrievalEndEvent; /** * @deprecated */ diff --git a/packages/core/src/indices/vectorStore/index.ts b/packages/core/src/indices/vectorStore/index.ts index b81d897afc..2dc5922c89 100644 --- a/packages/core/src/indices/vectorStore/index.ts +++ b/packages/core/src/indices/vectorStore/index.ts @@ -1,14 +1,8 @@ -import type { - BaseNode, - Document, - Metadata, - NodeWithScore, -} from "../../Node.js"; +import type { BaseNode, Document, NodeWithScore } from "../../Node.js"; import { ImageNode, ObjectType, splitNodesByType } from "../../Node.js"; import type { BaseRetriever, RetrieveParams } from "../../Retriever.js"; import type { ServiceContext } from "../../ServiceContext.js"; import { - Settings, embedModelFromSettingsOrContext, nodeParserFromSettingsOrContext, } from "../../Settings.js"; @@ -25,6 +19,7 @@ import { createDocStoreStrategy, } from "../../ingestion/strategies/index.js"; import { wrapEventCaller } from "../../internal/context/EventCaller.js"; +import { getCallbackManager } from "../../internal/settings/CallbackManager.js"; import type { BaseNodePostprocessor } from "../../postprocessors/types.js"; import type { StorageContext } from "../../storage/StorageContext.js"; import { storageContextFromDefaults } from "../../storage/StorageContext.js"; @@ -411,10 +406,16 @@ export class VectorIndexRetriever implements BaseRetriever { this.imageSimilarityTopK = imageSimilarityTopK ?? DEFAULT_SIMILARITY_TOP_K; } + @wrapEventCaller async retrieve({ query, preFilters, }: RetrieveParams): Promise { + getCallbackManager().dispatchEvent("retrieve-start", { + payload: { + query, + }, + }); let nodesWithScores = await this.textRetrieve( query, preFilters as MetadataFilters, @@ -422,7 +423,17 @@ export class VectorIndexRetriever implements BaseRetriever { nodesWithScores = nodesWithScores.concat( await this.textToImageRetrieve(query, preFilters as MetadataFilters), ); - this.sendEvent(query, nodesWithScores); + getCallbackManager().dispatchEvent("retrieve-end", { + payload: { + query, + nodes: nodesWithScores, + }, + }); + // send deprecated event + getCallbackManager().dispatchEvent("retrieve", { + query, + nodes: nodesWithScores, + }); return nodesWithScores; } @@ -459,17 +470,6 @@ export class VectorIndexRetriever implements BaseRetriever { return this.buildNodeListFromQueryResult(result); } - @wrapEventCaller - protected sendEvent( - query: string, - nodesWithScores: NodeWithScore[], - ) { - Settings.callbackManager.dispatchEvent("retrieve", { - query, - nodes: nodesWithScores, - }); - } - protected async buildVectorStoreQuery( embedModel: BaseEmbedding, query: string, diff --git a/packages/core/src/llm/types.ts b/packages/core/src/llm/types.ts index c33291b82b..a2debd790b 100644 --- a/packages/core/src/llm/types.ts +++ b/packages/core/src/llm/types.ts @@ -1,7 +1,15 @@ import type { Tokenizers } from "../GlobalsHelper.js"; +import type { NodeWithScore } from "../Node.js"; import type { BaseEvent } from "../internal/type.js"; import type { BaseTool, JSONObject, ToolOutput, UUID } from "../types.js"; +export type RetrievalStartEvent = BaseEvent<{ + query: string; +}>; +export type RetrievalEndEvent = BaseEvent<{ + query: string; + nodes: NodeWithScore[]; +}>; export type LLMStartEvent = BaseEvent<{ id: UUID; messages: ChatMessage[];