Skip to content

Commit

Permalink
Fixed regression issues
Browse files Browse the repository at this point in the history
1. Fixed the "no message" regression in the retrieval API
2. Improved caching and citation link generation
3. Shut the stream down gracefully after the last chunk
4. Commented out some debugging logs to keep the console clear
  • Loading branch information
rohan-uiuc committed Dec 19, 2023
1 parent 72f76ab commit 50768f5
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 55 deletions.
111 changes: 61 additions & 50 deletions src/components/Chat/Chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
}

const [inputContent, setInputContent] = useState<string>('')
const [cacheMetrics, setCacheMetrics] = useState({ hits: 0, misses: 0 });

useEffect(() => {
if (courseMetadata?.banner_image_s3 && courseMetadata.banner_image_s3 !== '') {
Expand Down Expand Up @@ -235,34 +236,19 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
return searchQuery;
}

const handleContextSearch = async (message: Message, selectedConversation: Conversation) => {
const handleContextSearch = async (message: Message, selectedConversation: Conversation, searchQuery: string) => {
if (getCurrentPageName() != 'gpt4') {
// Extract text from all user messages in the conversation
const userMessagesText = selectedConversation.messages
.filter(msg => msg.role === 'user') //TODO: Remove this when we add message filtering/summarizing step to backend
.map(msg => {
if (typeof msg.content === 'string') {
return msg.content;
} else if (Array.isArray(msg.content)) {
// Concatenate all text contents
return msg.content
.filter(content => content.type === 'text')
.map(content => content.text)
.join(' ');
}
return '';
})
.join('\n'); // Join all user messages into a single string

const token_limit = OpenAIModels[selectedConversation?.model.id as OpenAIModelID].tokenLimit;
await fetchContexts(getCurrentPageName(), userMessagesText, token_limit).then((curr_contexts) => {
message.contexts = curr_contexts as ContextWithMetadata[];
});
const token_limit = OpenAIModels[selectedConversation?.model.id as OpenAIModelID].tokenLimit
await fetchContexts(getCurrentPageName(), searchQuery, token_limit).then((curr_contexts) => {
message.contexts = curr_contexts as ContextWithMetadata[]
})
}
}

const generateCitationLink = async (context: ContextWithMetadata) => {
console.log('context: ', context);
// Uncomment for debugging
// console.log('context: ', context);
if (context.url) {
return context.url;
} else if (context.s3_path) {
Expand All @@ -272,16 +258,39 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
}

/**
 * Resolves the link for a citation, reusing the link already stored in
 * `citationLinkCache` for the same citation index when there is one.
 *
 * On a miss the link is produced by `generateCitationLink` and written back
 * into the cache. Either way the hit/miss counters held in `cacheMetrics`
 * state are bumped.
 *
 * @param context - Retrieved context the citation points at.
 * @param citationLinkCache - Cache keyed by citation index; mutated on a miss.
 * @param citationIndex - Cache key for this citation (the visible caller passes
 *   the number taken from the `[n]` marker in the message text).
 * @returns The cached link, or the freshly generated one.
 */
const getCitationLink = async (context: ContextWithMetadata, citationLinkCache: Map<number, string>, citationIndex: number) => {
const cachedLink = citationLinkCache.get(citationIndex);
// Compare against undefined (not truthiness) so a cached empty-string link still counts as a hit.
if (cachedLink !== undefined) {
// Keep the updater pure: React may invoke updater functions more than once (e.g. in Strict Mode),
// so logging does not belong inside it.
setCacheMetrics((prevMetrics) => ({ ...prevMetrics, hits: prevMetrics.hits + 1 }));
// Uncomment for debugging
// console.log(`Cache hit for citation index ${citationIndex}`);
return cachedLink;
}

setCacheMetrics((prevMetrics) => ({ ...prevMetrics, misses: prevMetrics.misses + 1 }));
// Uncomment for debugging
// console.log(`Cache miss for citation index ${citationIndex}`);

// Only announce link generation when a link is actually generated (i.e. on a miss).
console.log("Generating citation link for context: ", citationIndex, context.readable_filename)
const link = await generateCitationLink(context);
citationLinkCache.set(citationIndex, link);
return link;
}

/**
 * Logs the current cache hit/miss counters and then zeroes them.
 *
 * NOTE(review): `cacheMetrics` is read from the enclosing closure, so the
 * logged value is whatever that render captured; updates queued through
 * `setCacheMetrics` since then are not visible here. When this is invoked from
 * a long-running async handler the printed numbers may lag behind the real
 * counts — confirm whether a ref is needed to report accurate totals.
 */
const resetCacheMetrics = () => {
// console.log(`Final cache hit ratio for the message: ${(cacheMetrics.hits / (cacheMetrics.hits + cacheMetrics.misses)).toFixed(2)}`);
const metricsSnapshot = JSON.stringify(cacheMetrics);
console.log(`Final Cache metrics: ${metricsSnapshot}`);

// Start the next message with fresh counters.
const emptyMetrics = { hits: 0, misses: 0 };
setCacheMetrics(emptyMetrics);
}

/**
 * Escapes every regular-expression metacharacter in `string` so the result can
 * be embedded in a `RegExp` source and match the input literally.
 *
 * @param string - Raw text (for example a filename) to be matched verbatim.
 * @returns The text with each metacharacter prefixed by a backslash.
 */
function escapeRegExp(string: string) {
const metaCharacters = /[.*+?^${}()|[\]\\]/g;
// In a replacement string, `$&` expands to the whole matched substring.
const backslashPlusMatch = '\\$&';
return string.replace(metaCharacters, backslashPlusMatch);
}

// THIS IS WHERE MESSAGES ARE SENT.
const handleSend = useCallback(
async (message: Message, deleteCount = 0, plugin: Plugin | null = null) => {
Expand Down Expand Up @@ -331,7 +340,7 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
}

// Run context search, attach to Message object.
await handleContextSearch(message, selectedConversation);
await handleContextSearch(message, selectedConversation, searchQuery);

const chatBody: ChatBody = {
model: updatedConversation.model,
Expand Down Expand Up @@ -433,6 +442,7 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
let done = false
let isFirst = true
let text = ''
const citationLinkCache = new Map<number, string>();
try {
while (!done) {
if (stopConversationRef.current === true) {
Expand All @@ -444,6 +454,7 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
done = doneReading
const chunkValue = decoder.decode(value)
text += chunkValue

if (isFirst) {
// isFirst refers to the first chunk of data received from the API (happens once for each new message from API)
isFirst = false
Expand All @@ -464,41 +475,38 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
value: updatedConversation,
})
} else {
const citationLinkCache = new Map<number, string>();

function escapeRegExp(string: string) {
return string.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); // $& means the whole matched string
}

const updatedMessagesPromises: Promise<Message>[] = updatedConversation.messages.map(async (message, index) => {
if (index === updatedConversation.messages.length - 1 && message.contexts) {
let content = text;
for (const context of message.contexts) {
// Extract page number from the content string
const pageMatch = content.match(new RegExp(`\\[${escapeRegExp(context.readable_filename)}, page: (\\d+)\\]\\(#\\)`));
const pageNumber = pageMatch ? `#page=${pageMatch[1]}` : '';

const citationIndex = message.contexts.indexOf(context) + 1;
const link = await getCitationLink(context, citationLinkCache, citationIndex);

const citationLinkPattern = new RegExp(`\\[${citationIndex}\\](?!\\([^)]*\\))`, 'g');
const citationLinkReplacement = `[${citationIndex}](${link}${pageNumber})`;
content = content.replace(citationLinkPattern, citationLinkReplacement);

const filenameLinkPattern = new RegExp(`(\\b${citationIndex}\\.)\\s*\\[(.*?)\\]\\(\\#\\)`, 'g');

// The replacement pattern uses backreferences ($1 and $2) to keep the original citation index and the filename provided by OpenAI intact.
// $1 is the citation index and period, $2 is the filename provided by OpenAI.
const filenameLinkReplacement = `$1 [${context.readable_filename}](${link}${pageNumber})`;
// Identify all unique citation indices in the content
const citationIndices = new Set<number>();
const citationPattern = /\[(\d+)\](?!\([^)]*\))/g;
let match;
while ((match = citationPattern.exec(content)) !== null) {
citationIndices.add(parseInt(match[1] as string));
}

// Perform the replacement
content = content.replace(filenameLinkPattern, (match, index, filename) => {
// Use the filename provided by OpenAI in the link text
return `${index} [${index} ${filename}](${link}${pageNumber})`;
});
// Generate citation links only for the referenced indices
for (const citationIndex of citationIndices) {
const context = message.contexts[citationIndex - 1]; // Adjust index for zero-based array
if (context) {
const link = await getCitationLink(context, citationLinkCache, citationIndex);
const pageNumberMatch = content.match(new RegExp(`\\[${escapeRegExp(context.readable_filename)}, page: (\\d+)\\]\\(#\\)`));
const pageNumber = pageNumberMatch ? `#page=${pageNumberMatch[1]}` : '';

// Replace citation index with link
content = content.replace(new RegExp(`\\[${citationIndex}\\](?!\\([^)]*\\))`, 'g'), `[${citationIndex}](${link}${pageNumber})`);

// Replace filename with link
content = content.replace(new RegExp(`(\\b${citationIndex}\\.)\\s*\\[(.*?)\\]\\(\\#\\)`, 'g'), (match, index, filename) => {
return `${index} [${index} ${filename}](${link}${pageNumber})`;
});
}
}
// Uncomment for debugging
console.log('content: ', content);
// console.log('content: ', content);
return { ...message, content };
}
return message;
Expand All @@ -522,6 +530,9 @@ export const Chat = memo(({ stopConversationRef, courseMetadata }: Props) => {
homeDispatch({ field: 'loading', value: false });
homeDispatch({ field: 'messageIsStreaming', value: false });
return;
} finally {
// Reset cache metrics after each message
resetCacheMetrics();
}

if (!done) {
Expand Down
2 changes: 1 addition & 1 deletion src/components/Chat/ChatMessage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ export const ChatMessage: FC<Props> = memo(
const { href, title } = props;
// console.log("href:", href);
// console.log("title:", title);
console.log("children:", children);
// console.log("children:", children);
const isCitationLink = /^\d+$/.test(children[0] as string);
if (isCitationLink) {
return (
Expand Down
16 changes: 12 additions & 4 deletions src/utils/server/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -145,14 +145,13 @@ export const OpenAIStream = async (

if (stream) {
console.log("Streaming response ")
let isStreamClosed = false; // Flag to track the state of the stream
const apiStream = new ReadableStream({
async start(controller) {
const onParse = (event: ParsedEvent | ReconnectInterval) => {
if (event.type === 'event') {
const data = event.data

let isStreamClosed = false; // Flag to track the state of the stream

try {
// console.log('data: ', data) // ! DEBUGGING
if (data.trim() !== "[DONE]") {
Expand Down Expand Up @@ -185,8 +184,17 @@ export const OpenAIStream = async (

const parser = createParser(onParse)

for await (const chunk of res.body as any) {
parser.feed(decoder.decode(chunk))
try {
for await (const chunk of res.body as any) {
if (!isStreamClosed) { // Only feed the parser if the stream is not closed
parser.feed(decoder.decode(chunk))
}
}
} catch (e) {
if (!isStreamClosed) {
controller.error(e)
isStreamClosed = true;
}
}
},
})
Expand Down

0 comments on commit 50768f5

Please sign in to comment.