From 50768f59b851b7354c6377c154e1a2aa4943e8b5 Mon Sep 17 00:00:00 2001
From: Rohan Marwaha
Date: Tue, 19 Dec 2023 14:48:49 -0600
Subject: [PATCH] Fixed regression issues

1. No message in retrieval API fixed
2. Caching and citation link generation improvement
3. Handle stream to shut it down gracefully on last chunk
4. Commented some debugging logs to keep the console clear
---
 src/components/Chat/Chat.tsx        | 111 +++++++++++++++-------------
 src/components/Chat/ChatMessage.tsx |   2 +-
 src/utils/server/index.ts           |  16 +++-
 3 files changed, 74 insertions(+), 55 deletions(-)

diff --git a/src/components/Chat/Chat.tsx b/src/components/Chat/Chat.tsx
index cb29da0b0..5ce58d5e0 100644
--- a/src/components/Chat/Chat.tsx
+++ b/src/components/Chat/Chat.tsx
@@ -90,6 +90,7 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
   }
 
   const [inputContent, setInputContent] = useState('')
+  const [cacheMetrics, setCacheMetrics] = useState({ hits: 0, misses: 0 });
 
   useEffect(() => {
     if (courseMetadata?.banner_image_s3 && courseMetadata.banner_image_s3 !== '') {
@@ -235,34 +236,19 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
     return searchQuery;
   }
 
-  const handleContextSearch = async (message: Message, selectedConversation: Conversation) => {
+  const handleContextSearch = async (message: Message, selectedConversation: Conversation, searchQuery: string) => {
     if (getCurrentPageName() != 'gpt4') {
       // Extract text from all user messages in the conversation
-      const userMessagesText = selectedConversation.messages
-        .filter(msg => msg.role === 'user') //TODO: Remove this when we add message filtering/summarizing step to backend
-        .map(msg => {
-          if (typeof msg.content === 'string') {
-            return msg.content;
-          } else if (Array.isArray(msg.content)) {
-            // Concatenate all text contents
-            return msg.content
-              .filter(content => content.type === 'text')
-              .map(content => content.text)
-              .join(' ');
-          }
-          return '';
-        })
-        .join('\n'); // Join all user messages into a single string
-
-      const token_limit = OpenAIModels[selectedConversation?.model.id as OpenAIModelID].tokenLimit;
-      await fetchContexts(getCurrentPageName(), userMessagesText, token_limit).then((curr_contexts) => {
-        message.contexts = curr_contexts as ContextWithMetadata[];
-      });
+      const token_limit = OpenAIModels[selectedConversation?.model.id as OpenAIModelID].tokenLimit
+      await fetchContexts(getCurrentPageName(), searchQuery, token_limit).then((curr_contexts) => {
+        message.contexts = curr_contexts as ContextWithMetadata[]
+      })
     }
   }
 
   const generateCitationLink = async (context: ContextWithMetadata) => {
-    console.log('context: ', context);
+    // Uncomment for debugging
+    // console.log('context: ', context);
     if (context.url) {
       return context.url;
     } else if (context.s3_path) {
@@ -272,16 +258,39 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
   }
 
   const getCitationLink = async (context: ContextWithMetadata, citationLinkCache: Map<number, string>, citationIndex: number) => {
+    console.log("Generating citation link for context: ", citationIndex, context.readable_filename)
     const cachedLink = citationLinkCache.get(citationIndex);
     if (cachedLink) {
+      setCacheMetrics((prevMetrics) => {
+        const newMetrics = { ...prevMetrics, hits: prevMetrics.hits + 1 };
+        // Uncomment for debugging
+        console.log(`Cache hit for citation index ${citationIndex}. Current cache hit ratio: ${(newMetrics.hits / (newMetrics.hits + newMetrics.misses)).toFixed(2)}`);
+        return newMetrics;
+      });
       return cachedLink;
     } else {
+      setCacheMetrics((prevMetrics) => {
+        const newMetrics = { ...prevMetrics, misses: prevMetrics.misses + 1 };
+        // Uncomment for debugging
+        console.log(`Cache miss for citation index ${citationIndex}. Current cache hit ratio: ${(newMetrics.hits / (newMetrics.hits + newMetrics.misses)).toFixed(2)}`);
+        return newMetrics;
+      });
       const link = await generateCitationLink(context);
       citationLinkCache.set(citationIndex, link);
       return link;
    }
  }

+  const resetCacheMetrics = () => {
+    // console.log(`Final cache hit ratio for the message: ${(cacheMetrics.hits / (cacheMetrics.hits + cacheMetrics.misses)).toFixed(2)}`);
+    console.log(`Final Cache metrics: ${JSON.stringify(cacheMetrics)}`);
+    setCacheMetrics({ hits: 0, misses: 0 });
+  }
+
+  function escapeRegExp(string: string) {
+    return string.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); // $& means the whole matched string
+  }
+
   // THIS IS WHERE MESSAGES ARE SENT.
   const handleSend = useCallback(
     async (message: Message, deleteCount = 0, plugin: Plugin | null = null) => {
@@ -331,7 +340,7 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
       }
 
       // Run context search, attach to Message object.
-      await handleContextSearch(message, selectedConversation);
+      await handleContextSearch(message, selectedConversation, searchQuery);
 
       const chatBody: ChatBody = {
         model: updatedConversation.model,
@@ -433,6 +442,7 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
       let done = false
       let isFirst = true
      let text = ''
+      const citationLinkCache = new Map<number, string>();
       try {
         while (!done) {
           if (stopConversationRef.current === true) {
@@ -444,6 +454,7 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
           done = doneReading
           const chunkValue = decoder.decode(value)
           text += chunkValue
+
           if (isFirst) {
             // isFirst refers to the first chunk of data received from the API (happens once for each new message from API)
             isFirst = false
@@ -464,41 +475,38 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
               value: updatedConversation,
             })
           } else {
-            const citationLinkCache = new Map<number, string>();
-
-            function escapeRegExp(string: string) {
-              return string.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); // $& means the whole matched string
-            }
             const updatedMessagesPromises: Promise<Message>[] = updatedConversation.messages.map(async (message, index) => {
               if (index === updatedConversation.messages.length - 1 && message.contexts) {
                 let content = text;
-                for (const context of message.contexts) {
-                  // Extract page number from the content string
-                  const pageMatch = content.match(new RegExp(`\\[${escapeRegExp(context.readable_filename)}, page: (\\d+)\\]\\(#\\)`));
-                  const pageNumber = pageMatch ? `#page=${pageMatch[1]}` : '';
-                  const citationIndex = message.contexts.indexOf(context) + 1;
-                  const link = await getCitationLink(context, citationLinkCache, citationIndex);
-
-                  const citationLinkPattern = new RegExp(`\\[${citationIndex}\\](?!\\([^)]*\\))`, 'g');
-                  const citationLinkReplacement = `[${citationIndex}](${link}${pageNumber})`;
-                  content = content.replace(citationLinkPattern, citationLinkReplacement);
-
-                  const filenameLinkPattern = new RegExp(`(\\b${citationIndex}\\.)\\s*\\[(.*?)\\]\\(\\#\\)`, 'g');
-
-                  // The replacement pattern uses backreferences ($1 and $2) to keep the original citation index and the filename provided by OpenAI intact.
-                  // $1 is the citation index and period, $2 is the filename provided by OpenAI.
-                  const filenameLinkReplacement = `$1 [${context.readable_filename}](${link}${pageNumber})`;
+                // Identify all unique citation indices in the content
+                const citationIndices = new Set<number>();
+                const citationPattern = /\[(\d+)\](?!\([^)]*\))/g;
+                let match;
+                while ((match = citationPattern.exec(content)) !== null) {
+                  citationIndices.add(parseInt(match[1] as string));
+                }
 
-                  // Perform the replacement
-                  content = content.replace(filenameLinkPattern, (match, index, filename) => {
-                    // Use the filename provided by OpenAI in the link text
-                    return `${index} [${index} $(unknown)](${link}${pageNumber})`;
-                  });
+                // Generate citation links only for the referenced indices
+                for (const citationIndex of citationIndices) {
+                  const context = message.contexts[citationIndex - 1]; // Adjust index for zero-based array
+                  if (context) {
+                    const link = await getCitationLink(context, citationLinkCache, citationIndex);
+                    const pageNumberMatch = content.match(new RegExp(`\\[${escapeRegExp(context.readable_filename)}, page: (\\d+)\\]\\(#\\)`));
+                    const pageNumber = pageNumberMatch ? `#page=${pageNumberMatch[1]}` : '';
+
+                    // Replace citation index with link
+                    content = content.replace(new RegExp(`\\[${citationIndex}\\](?!\\([^)]*\\))`, 'g'), `[${citationIndex}](${link}${pageNumber})`);
+
+                    // Replace filename with link
+                    content = content.replace(new RegExp(`(\\b${citationIndex}\\.)\\s*\\[(.*?)\\]\\(\\#\\)`, 'g'), (match, index, filename) => {
+                      return `${index} [${index} $(unknown)](${link}${pageNumber})`;
+                    });
+                  }
                 }
 
                 // Uncomment for debugging
-                console.log('content: ', content);
+                // console.log('content: ', content);
                 return { ...message, content };
               }
               return message;
             });
@@ -522,6 +530,9 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
        homeDispatch({ field: 'loading', value: false });
        homeDispatch({ field: 'messageIsStreaming', value: false });
        return;
+      } finally {
+        // Reset cache metrics after each message
+        resetCacheMetrics();
      }

       if (!done) {
diff --git a/src/components/Chat/ChatMessage.tsx b/src/components/Chat/ChatMessage.tsx
index 5679b3b55..0794484db 100644
--- a/src/components/Chat/ChatMessage.tsx
+++ b/src/components/Chat/ChatMessage.tsx
@@ -552,7 +552,7 @@ export const ChatMessage: FC<Props> = memo(
              const { href, title } = props;
              // console.log("href:", href);
              // console.log("title:", title);
-              console.log("children:", children);
+              // console.log("children:", children);
              const isCitationLink = /^\d+$/.test(children[0] as string);
              if (isCitationLink) {
                return (
diff --git a/src/utils/server/index.ts b/src/utils/server/index.ts
index 674522cd0..e879311c8 100644
--- a/src/utils/server/index.ts
+++ b/src/utils/server/index.ts
@@ -145,14 +145,13 @@ export const OpenAIStream = async (
 
   if (stream) {
     console.log("Streaming response ")
+    let isStreamClosed = false; // Flag to track the state of the stream
     const apiStream = new ReadableStream({
       async start(controller) {
         const onParse = (event: ParsedEvent | ReconnectInterval) => {
           if (event.type === 'event') {
             const data = event.data
 
-            let isStreamClosed = false; // Flag to track the state of the stream
-
             try {
               // console.log('data: ', data) // ! DEBUGGING
               if (data.trim() !== "[DONE]") {
@@ -185,8 +184,17 @@ export const OpenAIStream = async (
 
         const parser = createParser(onParse)
 
-        for await (const chunk of res.body as any) {
-          parser.feed(decoder.decode(chunk))
+        try {
+          for await (const chunk of res.body as any) {
+            if (!isStreamClosed) { // Only feed the parser if the stream is not closed
+              parser.feed(decoder.decode(chunk))
+            }
+          }
+        } catch (e) {
+          if (!isStreamClosed) {
+            controller.error(e)
+            isStreamClosed = true;
+          }
         }
       },
     })