Arxiv Demo Improvements #119

Merged
merged 16 commits
Feb 23, 2024
229 changes: 196 additions & 33 deletions examples/solara/arxiv-chat/ai.py
@@ -4,22 +4,46 @@
from IPython import display
from datetime import datetime
import tiktoken
import articles as art
from articles import ArxivClient
from scipy.spatial import KDTree
import numpy as np

from datetime import datetime

TOKEN_LIMIT = 3750 # allow some buffer so responses aren't cut off

PROMPT_MESSAGES = [
"You are a system designed to give users information about scientific articles from Arxiv.",
"The user will start by entering a topic, or entering a URL to an Arxiv article.",
"If the user enters a topic, call the fetch_articles function.",
"If the user enters a URL, call the download_article_from_url function.",
"Only call one function at a time. Don't chain multiple function calls together.",
]

def current_time():
return datetime.now().strftime("%H:%M:%S")

def print_msg(msg):
print(f"[{current_time()}]: {msg}")

class OpenAIClient:
def __init__(self):
self.client = OpenAI()
self.store = EmbeddingsStore()
self.encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
self.messages = []
self.articles = []
self.messages = [None, ]
self.articles = [None for _ in range(6)]
self.article_chunks = []
self.article_focus_id = None
self.load_messages()
self.load_tools()
self.load_categories()

def load_messages(self):
for msg in list(PROMPT_MESSAGES):
self.messages.append({
"role": "system",
"content": msg
})


def load_categories(self):
@@ -38,7 +62,8 @@ def get_articles(self):
articles = None
with open(path, "r") as articles_json:
articles = json.load(articles_json)
self.articles = articles
for i in range(5):
self.articles[i] = articles[i]
return articles


@@ -52,57 +77,59 @@ def count_tokens(self, msgs):

def trim_messages(self, token_count):
while token_count > TOKEN_LIMIT:
first_msg = self.messages.pop(0)
first_msg = self.messages.pop(1)
token_count -= len(self.encoding.encode(str(first_msg)))
print_msg(f"Trimmed messages to: {token_count} tokens.")
return token_count
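
For context: count_tokens (collapsed in the hunk above) presumably encodes each message with the tiktoken encoding created in __init__, and trim_messages now pops from index 1 because self.messages[0] is reserved for the prompt that load_prompt() writes in place. A rough sketch of that counting under those assumptions; the standalone helper below is illustrative, not part of this PR:

def count_tokens_sketch(encoding, msgs):
    # Sum the token counts of the stringified messages, skipping the reserved
    # slot-0 placeholder while it is still None.
    return sum(len(encoding.encode(str(m))) for m in msgs if m is not None)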


def fetch_articles_from_query(self, query):
ac = art.ArxivClient()
store = EmbeddingsStore()
topic = self.topic_classify_categories(query)
def fetch_articles_from_query(self, query, criterion="relevance", order="descending"):
ac = ArxivClient()
articles_raw = None

if topic in self.categories:
articles_raw = ac.get_articles_by_cat(topic)

topic = self.topic_classify_terms(query)
if len(topic.split()) > 10:
return False, topic
else:
topic = self.topic_classify_terms(query)
if len(topic.split()) > 10:
return False, topic
else:
articles_raw = ac.get_articles_by_terms(topic)

articles_raw = ac.get_articles_by_terms(topic, criterion, order)

articles = ac.results_to_array(articles_raw)
embeddings = store.get_many(articles)
embeddings = self.store.get_bunch(articles)

try:
kdtree = KDTree(np.array(embeddings))
except:
help_msg = "There was a problem processing that message. Can you please try again? \n\n I can help you with a wide range of topics, including but not limited to: mathematics, computer science, astrophysics, statistics, and quantitative biology!"
return False, help_msg

_, indexes = kdtree.query(store.get_one(query), k=5)
_, indexes = kdtree.query(self.store.get_one(query), k=5)
relevant_articles = [articles_raw[i] for i in indexes]

ac.results_to_json(relevant_articles)
self.load_prompt()

return True, None
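
The article-selection step above is a nearest-neighbour search in embedding space. A minimal standalone sketch of the same pattern (function and variable names here are illustrative, not part of this PR):

import numpy as np
from scipy.spatial import KDTree

def pick_top_k(query_vec, candidate_vecs, k=5):
    # Build a KD-tree over the candidate embeddings and return the indexes of
    # the k candidates closest to the query embedding (Euclidean distance).
    tree = KDTree(np.array(candidate_vecs))
    _, indexes = tree.query(query_vec, k=k)
    return list(indexes)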


def load_prompt(self, verbose=False):
prompt = f"""
You are a helpful assistant that can answer questions about scientific articles.
Here are the articles info in JSON format:
Here is the articles' info in dictionary format:

Use these to generate your answer,
and disregard articles that are not relevant to answer the question:

{str(self.get_articles())}

Additional instructions:
1. Obey the rules set in each tool's description.
2. You will have a knowledge base of 5 articles and 1 article to focus on at a time. The article to focus on, or 'current article context', may or may not be in the main set of 5.
3. Assume the user is asking about the current article in focus. This is usually the paper that the user most recently asked about.
"""
if verbose:
print(prompt)

self.messages.append({"role": "system", "content": prompt})
self.messages[0] = {"role": "system", "content": prompt}


def display_response(self, response):
@@ -115,14 +142,45 @@ def display_response(self, response):
{content}"""))
return response

def call_fetch_articles_tool_for_query_params(self, query):
prompt = f"""
Given the query, call the fetch_articles function.
Do not return any text, just call the tool. Here is the query:

{query}
"""
response = self.client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[{"role": "system", "content": prompt}],
tools=[self.tools[-1]],
seed=42,
n=1,
)

msg = dict(response.choices[0].message)

if not msg.get("tool_calls"):
return "relevance", "descending"

try:
args = msg["tool_calls"]
args = dict(eval(args[0].function.arguments))
except:
return "relevance", "descending"

return args["sort_criterion"], args["sort_order"]
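
The helper above leans on the fetch_articles tool definition loaded in load_tools(), which sits outside the hunks shown here. A hedged guess at the shape such a definition would take in the OpenAI function-calling format; the parameter names mirror the arguments read above, while the description text and enum values are assumptions:

FETCH_ARTICLES_TOOL = {
    "type": "function",
    "function": {
        "name": "fetch_articles",
        "description": "Fetch Arxiv articles matching a topic query.",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {"type": "string"},
                "sort_criterion": {
                    "type": "string",
                    "enum": ["relevance", "lastUpdatedDate", "submittedDate"],
                },
                "sort_order": {
                    "type": "string",
                    "enum": ["ascending", "descending"],
                },
            },
            "required": ["query"],
        },
    },
}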


def topic_classify_categories(self, user_query):
system_prompt = f"""
You're a system that determines the topic of a question about academic articles in the field of Math and Science.

Given a user prompt, you should categorize it into an article category.
Categories will be provided in a JSON dictionary format. Please return the category code,
which would be the key in the key, value pair.

Only return a code if the category is explicitly mentioned in the query. If you aren't sure, don't return a code.

{self.categories}
"""
@@ -151,8 +209,9 @@ def topic_classify_terms(self, user_query):
system_prompt = f"""
You're a system that determines the topic of a question about academic articles in the field of Math and Science.

Given a user prompt, you should categorize it into a set of article search terms.
Keep it to a few essential terms. Here is a list of examples:
Given a user prompt, you should simplify it into a set of article search terms.
Your response should be less than 5 words. When in doubt, just return the query without filler words. Here is a list of examples:


{self.categories.values()}
"""
@@ -182,7 +241,7 @@ def topic_classify_terms(self, user_query):

def get_article_by_title(self, title):
for a in self.articles:
if a["title"] == title:
if a is not None and a["title"] == title:
return a
return None

@@ -206,7 +265,6 @@ def get_authors(self, arguments):
def get_links(self, arguments):
try:
article = self.get_article_by_title(arguments["title"])

return article["links"]

except:
@@ -240,19 +298,77 @@ def get_categories(self, arguments):
def fetch_articles(self, arguments):
try:
query = arguments["query"]
success, content = self.fetch_articles_from_query(query)

sort_criterion = arguments["sort_criterion"]
sort_order = arguments["sort_order"]
success, content = self.fetch_articles_from_query(query, sort_criterion, sort_order)

if not success:
return content
except:
return f"There was a problem answering your question: {content}"

return "FETCHED-NEED-SUMMARIZE"


def answer_question(self, arguments):
try:
id, query = arguments["id"], arguments["query"]
chunk = self._get_article_chunk(id, query)
prompt = f"""
Use this chunk of the article to answer the user's question.
{chunk}

Here is the user's question:
{query}
"""
self.messages.append({
"role": "system",
"content": prompt,
})

print_msg("Getting response from Open AI.")
response = self.client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[msg for msg in self.messages if msg is not None],
seed=42,
n=1,
)

answer = response.choices[0].message.content
print_msg(answer)
self.messages.append({"role": "assistant", "content": answer})
return answer
except:
return f"There was a problem answering your question, try rephrasing."


def download_article_from_url(self, arguments):
url = arguments["url"]
id = url.split("/")[-1]
self._load_article_chunks(id)
message = f"Summarize the paper that was just provided in a few sentences based on this description: {self.articles[-1]['description']}. Don't call a function."
content = ""
for response in self.article_chat(message):
content = response
print_msg(f"Response from Open AI: {content}")
content = f"Got your article. Here's a summary: \n\n {content} \n\n Now I can answer your questions, so ask away!"
self.messages.append({
"role": "assistant",
"content": content
})
return content


def call_tool(self, call):
func_name, args = call["name"], dict(eval(call["arguments"]))
print_msg(f"Function call: {call}")
try:
func_name, args = call["name"], dict(eval(call["arguments"]))
except:
return "There was a problem answering this question, try rephrasing."

content = getattr(self, func_name)(args)

self.messages.append( # adding assistant response to messages
{
"role": "assistant",
@@ -263,7 +379,7 @@ def call_tool(self, call):
"content": ""
}
)
content = getattr(self, func_name)(args)

self.messages.append( # adding function response to messages
{
"role": "function",
@@ -282,7 +398,7 @@ def article_chat(self, user_query):

response = self.client.chat.completions.create(
model="gpt-3.5-turbo",
messages=self.messages,
messages=[msg for msg in self.messages if msg is not None],
tools=self.tools,
tool_choice="auto",
seed=42,
@@ -313,14 +429,42 @@ def article_chat(self, user_query):
yield answer

self.messages.append({"role": "assistant", "content": answer})


def _load_article_chunks(self, id=None):
if self.article_focus_id == id:
print_msg("Already downloaded this article.")
return

print(f"Downloading articles, id: {id}")
info, chunks = ArxivClient().download_article(id)
self.article_focus_id = id
self.article_chunks = chunks
self.articles[-1] = info
self.messages.append({
"role": "system",
"content": f"The current paper context is this title: {info['title']}. Unless the user mentions a different article, assume the user is asking about this article. You may call any of the tools."
})


def _get_article_chunk(self, id=None, query=None):
self._load_article_chunks(id)
chunk_embeddings = self.store.get_bunch(self.article_chunks)

query_embedding = self.store.get_one(query)
kdtree = KDTree(np.array(chunk_embeddings))

_, index = kdtree.query(query_embedding, k=1)
relevant_chunk = self.article_chunks[index]

return relevant_chunk
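
Taken together, _load_article_chunks and _get_article_chunk give the chat a single-paper focus: download once, cache the chunks, then answer each question against the chunk nearest the query embedding. A rough usage sketch of that flow (the arXiv id, URL, and question are placeholders):

client = OpenAIClient()

# Pull a specific paper into focus; the id below is made up for illustration.
client.download_article_from_url({"url": "https://arxiv.org/abs/0000.00000"})

# Questions are now answered against the most relevant chunk of that paper.
print(client.answer_question({"id": "0000.00000", "query": "What problem does the paper address?"}))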



class EmbeddingsStore:
def __init__(self):
self._path = Path("./json/embeddings.json")


if not self._path.exists() or not self._path.read_text():
self._data = {}
else:
@@ -342,6 +486,25 @@ def get_one(self, text):

return embedding


def get_bunch(self, content):
print_msg("Getting many embeddings.")
try:
response = OpenAIClient().client.embeddings.create(
input=content,
model="text-embedding-3-small"
)
except:
return self.get_many(content)
print_msg("Received response.")
embeddings = [item.embedding for item in response.data]
for i, text in enumerate(content):
self._data[text] = embeddings[i]

self._path.write_text(json.dumps(self._data))
print_msg(f"Returned {len(embeddings)} embeddings.")
return embeddings
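
A short usage sketch for the batch path above (the chunk strings are placeholders): get_bunch embeds the whole list in a single API call and records each text's vector in the JSON store, falling back to the one-at-a-time get_many path if the batch call fails.

store = EmbeddingsStore()
chunks = ["First placeholder chunk of a paper.", "Second placeholder chunk."]
vectors = store.get_bunch(chunks)
print(len(vectors), len(vectors[0]))  # e.g. 2 vectors of 1536 floats each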


def get_many(self, content):
return [self.get_one(text) for text in content]