diff --git a/examples/solara/arxiv-chat/ai.py b/examples/solara/arxiv-chat/ai.py index 5599fe4c..d344ced3 100644 --- a/examples/solara/arxiv-chat/ai.py +++ b/examples/solara/arxiv-chat/ai.py @@ -4,22 +4,46 @@ from IPython import display from datetime import datetime import tiktoken -import articles as art +from articles import ArxivClient from scipy.spatial import KDTree import numpy as np - +from datetime import datetime TOKEN_LIMIT = 3750 # allow some buffer so responses aren't cut off +PROMPT_MESSAGES = [ + "You are a system designed to give users information about scientific articles from Arxiv.", + "The user will start by entering a topic, or entering a URL to an Arxiv article.", + "If the user enters a topic, call the fetch_articles function.", + "If the user enters a URL, call the download_article_from_url function.", + "Only call one function at a time. Don't chain multiple function calls together.", +] + +def current_time(): + return datetime.now().strftime("%H:%M:%S") + +def print_msg(msg): + print(f"[{current_time()}]: {msg}") class OpenAIClient: def __init__(self): self.client = OpenAI() + self.store = EmbeddingsStore() self.encoding = tiktoken.encoding_for_model("gpt-3.5-turbo") - self.messages = [] - self.articles = [] + self.messages = [None, ] + self.articles = [None for _ in range(6)] + self.article_chunks = [] + self.article_focus_id = None + self.load_messages() self.load_tools() self.load_categories() + + def load_messages(self): + for msg in list(PROMPT_MESSAGES): + self.messages.append({ + "role": "system", + "content": msg + }) def load_categories(self): @@ -38,7 +62,8 @@ def get_articles(self): articles = None with open(path, "r") as articles_json: articles = json.load(articles_json) - self.articles = articles + for i in range(5): + self.articles[i] = articles[i] return articles @@ -52,28 +77,24 @@ def count_tokens(self, msgs): def trim_messages(self, token_count): while token_count > TOKEN_LIMIT: - first_msg = self.messages.pop(0) + first_msg = self.messages.pop(1) token_count -= len(self.encoding.encode(str(first_msg))) + print_msg(f"Trimmed messages to: {token_count} tokens.") return token_count - def fetch_articles_from_query(self, query): - ac = art.ArxivClient() - store = EmbeddingsStore() - topic = self.topic_classify_categories(query) + def fetch_articles_from_query(self, query, criterion="relevance", order="descending"): + ac = ArxivClient() articles_raw = None - - if topic in self.categories: - articles_raw = ac.get_articles_by_cat(topic) + + topic = self.topic_classify_terms(query) + if len(topic.split()) > 10: + return False, topic else: - topic = self.topic_classify_terms(query) - if len(topic.split()) > 10: - return False, topic - else: - articles_raw = ac.get_articles_by_terms(topic) - + articles_raw = ac.get_articles_by_terms(topic, criterion, order) + articles = ac.results_to_array(articles_raw) - embeddings = store.get_many(articles) + embeddings = self.store.get_bunch(articles) try: kdtree = KDTree(np.array(embeddings)) @@ -81,28 +102,34 @@ def fetch_articles_from_query(self, query): help_msg = "There was a problem processing that message. Can you please try again? \n\n I can help you with a wide range of topics, including but not limited to: mathematics, computer science, astrophysics, statistics, and quantitative biology!" return False, help_msg - _, indexes = kdtree.query(store.get_one(query), k=5) + _, indexes = kdtree.query(self.store.get_one(query), k=5) relevant_articles = [articles_raw[i] for i in indexes] ac.results_to_json(relevant_articles) self.load_prompt() + return True, None def load_prompt(self, verbose=False): prompt = f""" You are a helpful assistant that can answer questions about scientific articles. -Here are the articles info in JSON format: +Here are the articles info in dictionary format: Use these to generate your answer, and disregard articles that are not relevant to answer the question: {str(self.get_articles())} + +Additional instructions: +1. Obey the rules set in each tool's description. +2. You will have a knowledge base of 5 articles and 1 article to focus on at a time. The article to focus on, or 'current article context', may or may not be in the main set of 5. +3. Assume the user is asking about the current article in focus. This is usually the paper that the user most recently asked about. """ if verbose: print(prompt) - self.messages.append({"role": "system", "content": prompt}) + self.messages[0] = {"role": "system", "content": prompt} def display_response(self, response): @@ -115,6 +142,34 @@ def display_response(self, response): {content}""")) return response + def call_fetch_articles_tool_for_query_params(self, query): + prompt = f""" +Given the query, call the fetch_articles function. +Do not return any text, just call the tool. Here is the query: + +{query} +""" + response = self.client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[{"role": "system", "content": prompt}], + tools=[self.tools[-1]], + seed=42, + n=1, + ) + + msg = dict(response.choices[0].message) + + if not msg.get("tool_calls"): + return "relevance", "descending" + + try: + args = msg["tool_calls"] + args = dict(eval(args[0].function.arguments)) + except: + return "relevance", "descending" + + return args["sort_criterion"], args["sort_order"] + def topic_classify_categories(self, user_query): system_prompt = f""" @@ -122,7 +177,10 @@ def topic_classify_categories(self, user_query): Given a user prompt, you should categorize it into an article category. Categories will be provided in a JSON dictionary format. Please return the category code, -which would be the key in the key, value pair. + +which would be the key in the key, value pair. + +Only return a code if the category is explicitly mentioned in the query. If you aren't sure, don't return a code. {self.categories} """ @@ -151,8 +209,9 @@ def topic_classify_terms(self, user_query): system_prompt = f""" You're a system that determines the topic of a question about academic articles in the field of Math and Science. -Given a user prompt, you should categorize it into a set of article search terms. -Keep it to a few essential terms. Here is a list of examples: +Given a user prompt, you should simplify it into a set of article search terms. +Your response should be less than 5 words. When in doubt, just return the query without filler words. Here is a list of examples: + {self.categories.values()} """ @@ -182,7 +241,7 @@ def topic_classify_terms(self, user_query): def get_article_by_title(self, title): for a in self.articles: - if a["title"] == title: + if a is not None and a["title"] == title: return a return None @@ -206,7 +265,6 @@ def get_authors(self, arguments): def get_links(self, arguments): try: article = self.get_article_by_title(arguments["title"]) - return article["links"] except: @@ -240,19 +298,77 @@ def get_categories(self, arguments): def fetch_articles(self, arguments): try: query = arguments["query"] - success, content = self.fetch_articles_from_query(query) + sort_criterion = arguments["sort_criterion"] + sort_order = arguments["sort_order"] + success, content = self.fetch_articles_from_query(query, sort_criterion, sort_order) + if not success: return content except: return f"There was a problem answering your question: {content}" return "FETCHED-NEED-SUMMARIZE" + + + def answer_question(self, arguments): + try: + id, query = arguments["id"], arguments["query"] + chunk = self._get_article_chunk(id, query) + prompt = f""" + Use this chunk of the article to answer the user's question. + {chunk} + + Here is the user's question: + {query} + """ + self.messages.append({ + "role": "system", + "content": prompt, + }) + + print_msg("Getting response from Open AI.") + response = self.client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[msg for msg in self.messages if msg is not None], + seed=42, + n=1, + ) + + answer = response.choices[0].message.content + print_msg(answer) + self.messages.append({"role": "assistant", "content": answer}) + return answer + except: + return f"There was a problem answering your question, try rephrasing." + + + def download_article_from_url(self, arguments): + url = arguments["url"] + id = url.split("/")[-1] + self._load_article_chunks(id) + message = f"Summarize the paper that was just provided in a few sentences based on this description: {self.articles[-1]['description']}. Don't call a function." + content = "" + for response in self.article_chat(message): + content = response + print_msg(f"Response from Open AI: {content}") + content = f"Got your article. Here's a summary: \n\n {content} \n\n Now I can answer your questions, so ask away!" + self.messages.append({ + "role": "assistant", + "content": content + }) + return content def call_tool(self, call): - func_name, args = call["name"], dict(eval(call["arguments"])) + print_msg(f"Function call: {call}") + try: + func_name, args = call["name"], dict(eval(call["arguments"])) + except: + return "There was a problem answering this question, try rephrasing." + content = getattr(self, func_name)(args) + self.messages.append( # adding assistant response to messages { "role": "assistant", @@ -263,7 +379,7 @@ def call_tool(self, call): "content": "" } ) - content = getattr(self, func_name)(args) + self.messages.append( # adding function response to messages { "role": "function", @@ -282,7 +398,7 @@ def article_chat(self, user_query): response = self.client.chat.completions.create( model="gpt-3.5-turbo", - messages=self.messages, + messages=[msg for msg in self.messages if msg is not None], tools=self.tools, tool_choice="auto", seed=42, @@ -313,6 +429,35 @@ def article_chat(self, user_query): yield answer self.messages.append({"role": "assistant", "content": answer}) + + + def _load_article_chunks(self, id=None): + if self.article_focus_id == id: + print_msg("Already downloaded this article.") + return + + print(f"Downloading articles, id: {id}") + info, chunks = ArxivClient().download_article(id) + self.article_focus_id = id + self.article_chunks = chunks + self.articles[-1] = info + self.messages.append({ + "role": "system", + "content": f"The current paper context is this title: {info['title']}. Unless the user mentions a different article, assume the user is asking about this article. You may call any of the tools." + }) + + + def _get_article_chunk(self, id=None, query=None): + self._load_article_chunks(id) + chunk_embeddings = self.store.get_bunch(self.article_chunks) + + query_embedding = self.store.get_one(query) + kdtree = KDTree(np.array(chunk_embeddings)) + + _, index = kdtree.query(query_embedding, k=1) + relevant_chunk = self.article_chunks[index] + + return relevant_chunk @@ -320,7 +465,6 @@ class EmbeddingsStore: def __init__(self): self._path = Path("./json/embeddings.json") - if not self._path.exists() or not self._path.read_text(): self._data = {} else: @@ -342,6 +486,25 @@ def get_one(self, text): return embedding + + def get_bunch(self, content): + print_msg("Getting many embeddings.") + try: + response = OpenAIClient().client.embeddings.create( + input=content, + model="text-embedding-3-small" + ) + except: + return self.get_many(content) + print_msg("Received response.") + embeddings = [item.embedding for item in response.data] + for i, text in enumerate(content): + self._data[text] = embeddings[i] + + self._path.write_text(json.dumps(self._data)) + print_msg(f"Returned {len(embeddings)} embeddings.") + return embeddings + def get_many(self, content): return [self.get_one(text) for text in content] diff --git a/examples/solara/arxiv-chat/articles.py b/examples/solara/arxiv-chat/articles.py index 5f27b859..c1e3e041 100644 --- a/examples/solara/arxiv-chat/articles.py +++ b/examples/solara/arxiv-chat/articles.py @@ -1,27 +1,90 @@ import arxiv from pathlib import Path import json +import fitz +import tiktoken + +MAX_CHUNK_SIZE = 1500 # measured in tokens + + class ArxivClient: def __init__(self): self.client = arxiv.Client() - self.results = None + self.encoding = tiktoken.encoding_for_model("gpt-3.5-turbo") + + def token_length(self, text): + return len(self.encoding.encode(text)) - def _search(self, query): + def _search(self, query, criterion="relevance", order="descending"): + criterion_map = { + "relevance": arxiv.SortCriterion.Relevance, + "lastUpdatedDate": arxiv.SortCriterion.LastUpdatedDate, + "submittedDate": arxiv.SortCriterion.SubmittedDate, + } + order_map = { + "ascending": arxiv.SortOrder.Ascending, + "descending": arxiv.SortOrder.Descending, + } search = arxiv.Search( query=query, - max_results = 10, - sort_by = arxiv.SortCriterion.Relevance + max_results=10, + sort_by=criterion_map[criterion], + sort_order=order_map[order] ) + # print(f"Searching: {query}\nOrder by: {criterion} {order}") return search + def _search_by_id(self, id): + return list(self.client.results(arxiv.Search( + id_list=[id] + )))[0] + + def download_article(self, id=None): + result = self._search_by_id(id) + info = { + "id": result.get_short_id(), + "title": result.title, + "description": result.summary, + "published": str(result.published), + "authors": [a.name for a in result.authors], + "links": result.links[0].href, + "categories": result.categories, + } + + result.download_pdf(filename="article.pdf") + + doc = fitz.open("article.pdf") + # out = open("output.txt", "wb") + + print(f"Document length: {len(doc)}") + chunks = [] + curr_chunk = "" + + for p in doc: + t = p.get_text() + # out.write(t.encode("utf8")) # write text of pag + length = self.token_length(curr_chunk) + if length + self.token_length(t) <= MAX_CHUNK_SIZE: + curr_chunk += t + else: + # print(f"Length: {self.token_length(curr_chunk)}, Chunk: {curr_chunk[:15]}") + chunks.append(curr_chunk) + curr_chunk = t + + Path("article.pdf").unlink() + # out.close() + print("Downloaded file.") + return info, chunks + def get_articles_by_cat(self, query): query = f"cat:{query}" results = self.client.results(self._search(query)) return list(results) - def get_articles_by_terms(self, query): - results = self.client.results(self._search(query)) + + def get_articles_by_terms(self, query, criterion="relevance", order="descending"): + results = self.client.results(self._search(query, criterion, order)) return list(results) def results_to_json(self, results): @@ -29,6 +92,7 @@ def results_to_json(self, results): arr = [] for r in results: arr.append({ + "id": r.get_short_id(), "title": r.title, "description": r.summary, "published": str(r.published), diff --git a/examples/solara/arxiv-chat/chat.py b/examples/solara/arxiv-chat/chat.py index 296b8ccd..44e9a901 100644 --- a/examples/solara/arxiv-chat/chat.py +++ b/examples/solara/arxiv-chat/chat.py @@ -65,42 +65,26 @@ def Chat() -> None: Start by entering a math, science, or technology topic to learn about. \ Once I find you a set of articles, I can provide detailed information on each article including: \ author, description, category, published date, and download link.""" - ) + ), + Message( + role="assistant", + content="If you want to ask more detailed questions about an article, phrase them like \"In article 1, what is an LLM?\". If you provide a link to an article, I can also answer questions about it." + ), ]) - - loaded, set_loaded = sl.use_state(False) disabled, set_disabled = sl.use_state(False) - def load_articles_from_topic_query(query): - _messages = messages + [Message(role="user", content=query)] - set_messages(_messages + [Message(role="assistant", content="Processing...")]) - success, content = oc.fetch_articles_from_query(query) - - if not success: - set_messages(_messages + [Message(role="assistant", content=content)]) - return - - for new_message in oc.article_chat("Summarize each article in a sentence. Number them, and format like title: summary."): - set_messages(_messages + [Message(role="assistant", content=f"Fetched some articles.\n\n{new_message}")]) - - set_loaded(True) - def ask_chatgpt(input): set_disabled(True) _messages = messages + [Message(role="user", content=input)] set_messages(_messages) - if not loaded: - load_articles_from_topic_query(input) - set_disabled(False) - return for new_message in oc.article_chat(input): if new_message == "": set_messages(_messages + [Message(role="assistant", content="Processing...")]) elif new_message == "FETCHED-NEED-SUMMARIZE": - for msg in oc.article_chat("Summarize each article in a sentence. Number them, and format like title: summary."): + for msg in oc.article_chat("Summarize each article in a sentence. Number them and mention the title. Do not call any function."): set_messages(_messages + [Message(role="assistant", content=msg)]) else: diff --git a/examples/solara/arxiv-chat/json/categories.json b/examples/solara/arxiv-chat/json/categories.json index 1f28d6ca..a004ce4e 100644 --- a/examples/solara/arxiv-chat/json/categories.json +++ b/examples/solara/arxiv-chat/json/categories.json @@ -35,7 +35,6 @@ "cs.NA": "Numerical Analysis", "cs.NE": "Neural and Evolutionary Computing", "cs.NI": "Networking and Internet Architecture", - "cs.OH": "Other Computer Science", "cs.OS": "Operating Systems", "cs.PF": "Performance", "cs.PL": "Programming Languages", @@ -43,12 +42,10 @@ "cs.SC": "Symbolic Computation", "cs.SD": "Sound", "cs.SE": "Software Engineering", - "cs.SI": "Social and Information Networks", "cs.SY": "Systems and Control", "cond-mat.dis-nn": "Disordered Systems and Neural Networks", "cond-mat.mes-hall": "Mesoscale and Nanoscale Physics", "cond-mat.mtrl-sci": "Materials Science", - "cond-mat.other": "Other Condensed Matter", "cond-mat.quant-gas": "Quantum Gases", "cond-mat.soft": "Soft Condensed Matter", "cond-mat.stat-mech": "Statistical Mechanics", @@ -73,7 +70,6 @@ "physics.atm-clus": "Atomic and Molecular Clusters", "physics.atom-ph": "Atomic Physics", "physics.bio-ph": "Biological Physics", - "physics.chem-ph": "Chemical Physics", "physics.class-ph": "Classical Physics", "physics.comp-ph": "Computational Physics", "physics.data-an": "Data Analysis, Statistics and Probability", @@ -95,7 +91,6 @@ "q-bio.GN": "Genomics", "q-bio.MN": "Molecular Networks", "q-bio.NC": "Neurons and Cognition", - "q-bio.OT": "Other Quantitative Biology", "q-bio.PE": "Populations and Evolution", "q-bio.QM": "Quantitative Methods", "q-bio.SC": "Subcellular Processes", @@ -113,7 +108,6 @@ "stat.CO": "Computation", "stat.ME": "Methodology", "stat.ML": "Machine Learning", - "stat.OT": "Other Statistics", "stat.TH": "Statistics Theory", "gr-qc": "General Relativity And Quantum Cosmology", "hep-ex": "High Energy Physics - Experiment", diff --git a/examples/solara/arxiv-chat/json/tools.json b/examples/solara/arxiv-chat/json/tools.json index d9522b45..dbf5163d 100644 --- a/examples/solara/arxiv-chat/json/tools.json +++ b/examples/solara/arxiv-chat/json/tools.json @@ -4,7 +4,7 @@ "type": "function", "function": { "name": "get_description", - "description": "Return the detailed description of a given article", + "description": "Return the detailed description of a single article. Only called if the user specifically asks for the description. You may not call this function for several articles at a time.", "parameters": { "type": "object", "properties": { @@ -12,7 +12,8 @@ "type": "string", "description": "The title of the article" } - } + }, + "required": ["title"] } } }, @@ -28,7 +29,8 @@ "type": "string", "description": "The title of the article" } - } + }, + "required": ["title"] } } }, @@ -44,7 +46,8 @@ "type": "string", "description": "The title of the article" } - } + }, + "required": ["title"] } } }, @@ -60,7 +63,8 @@ "type": "string", "description": "The title of the article" } - } + }, + "required": ["title"] } } }, @@ -76,7 +80,46 @@ "type": "string", "description": "The title of the article" } - } + }, + "required": ["title"] + } + } + }, + { + "type": "function", + "function": { + "name": "answer_question", + "description": "If a user asks a question about a specific article that can't be answered with the other functions, call this function.", + "parameters": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "The ID of the article that the user is asking a question about" + }, + "query": { + "type": "string", + "description": "The user's question, verbatim. Don't alter it." + } + }, + "required": ["id", "query"] + } + } + }, + { + "type": "function", + "function": { + "name": "download_article_from_url", + "description": "Called if the user provides an arxiv url to a single article that they want to ask questions about. Example url: https://arxiv.org/abs/2402.00103. Cannot be called for multiple urls.", + "parameters": { + "type": "object", + "properties": { + "url": { + "type": "string", + "description": "The URL of the article to be downloaded. Do not modify the URL." + } + }, + "required": ["url"] } } }, @@ -84,15 +127,24 @@ "type": "function", "function": { "name": "fetch_articles", - "description": "Given a user query, fetch a new set of articles from Arxiv", + "description": "Given a user query, download a new set of articles from Arxiv. Do not call this function if you are asked to summarize articles.", "parameters": { "type": "object", "properties": { "query": { "type": "string", "description": "A query containing the topic that the user would like articles on" + }, + "sort_criterion": { + "type": "string", + "enum": ["relevance", "lastUpdatedDate", "submittedDate"] + }, + "sort_order": { + "type": "string", + "enum": ["ascending", "descending"] } - } + }, + "required": ["query", "sort_criterion", "sort_order"] } } } diff --git a/examples/solara/arxiv-chat/requirements.txt b/examples/solara/arxiv-chat/requirements.txt index b5787a94..07d33ca0 100644 --- a/examples/solara/arxiv-chat/requirements.txt +++ b/examples/solara/arxiv-chat/requirements.txt @@ -4,4 +4,5 @@ arxiv IPython scipy numpy -tiktoken \ No newline at end of file +tiktoken +pymupdf