merged context_filtering with MQR
star-nox committed Dec 5, 2023
1 parent b799444 commit 0f72b73
Showing 3 changed files with 40 additions and 10 deletions.
31 changes: 28 additions & 3 deletions ai_ta_backend/filtering_contexts.py
@@ -13,6 +13,7 @@
 from dotenv import load_dotenv
 from langchain import hub
 from langchain.prompts import PromptTemplate
+#from openai import OpenAI

 from ai_ta_backend.utils_tokenization import count_tokens_and_cost

@@ -33,10 +34,10 @@ def __init__(self):

   def filter_context(self, context, user_query, langsmith_prompt_obj):
     final_prompt = str(langsmith_prompt_obj.format(context=context, user_query=user_query))
-    # print(f"-------\nfinal_prompt:\n{final_prompt}\n^^^^^^^^^^^^^")
+    print(f"-------\nfinal_prompt:\n{final_prompt}\n^^^^^^^^^^^^^")
     try:
       # completion = run_model(final_prompt)
-      # completion = run_replicate(final_prompt)
+      #completion = run_replicate(final_prompt)
       completion = run_anyscale(final_prompt)
       return {"completion": completion, "context": context}
     except Exception as e:
@@ -81,11 +82,12 @@ def run_replicate(prompt):
   return output

 def run_anyscale(prompt):
+  print("in run anyscale")
   ret = openai.ChatCompletion.create(
     api_base = "https://api.endpoints.anyscale.com/v1",
     api_key=os.environ["ANYSCALE_ENDPOINT_TOKEN"],
     # model="meta-llama/Llama-2-70b-chat-hf",
-    model="mistralai/Mistral-7B-Instruct-v0.1",
+    engine="mistralai/Mistral-7B-Instruct-v0.1",
     messages=[{"role": "system", "content": "You are a helpful assistant."},
               {"role": "user", "content": prompt}],
     temperature=0.3,
@@ -94,6 +96,25 @@ def run_anyscale(prompt):
   print(ret["choices"][0]["message"]["content"])
   return ret["choices"][0]["message"]["content"]

+# def run_anyscale(prompt):
+#   print("in run anyscale")
+#   client = openai.OpenAI(
+#     base_url = "https://api.endpoints.anyscale.com/v1",
+#     api_key=os.environ["ANYSCALE_ENDPOINT_TOKEN"]
+#   )
+
+#   ret = client.chat.completions.create(
+#     # model="meta-llama/Llama-2-70b-chat-hf",
+#     model="mistralai/Mistral-7B-Instruct-v0.1",
+#     messages=[{"role": "system", "content": "You are a helpful assistant."},
+#               {"role": "user", "content": prompt}],
+#     temperature=0.3,
+#     max_tokens=250,
+#   )
+#   print("ANYSCALE RESPONSE: ", ret.choices[0].message.content)
+#   return ret.choices[0].message.content
+#   #return ret["choices"][0]["message"]["content"]
+
 def parse_result(result):
   lines = result.split('\n')
   for line in lines:
@@ -108,11 +129,15 @@ def run(contexts, user_query, max_tokens_to_return=3000, max_time_before_return=

   actor = AsyncActor.options(max_concurrency=max_concurrency).remote()
   result_futures = [actor.filter_context.remote(c, user_query, langsmith_prompt_obj) for c in contexts]
+  print("Num futures:", len(result_futures))
+  #print("Result futures:", result_futures)
+

   start_time = time.time()
   for i in range(0, len(result_futures)):
     try:
       ready, not_ready = ray.wait(result_futures)
+      print("ready:", ready)
       result = ray.get(ready[0])

       if result is None:
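For context on the new debug output in run(): the function fans each context out to a Ray actor and then harvests results one at a time with ray.wait. Below is a minimal, self-contained sketch of that fan-out/harvest pattern; the actor class, query string, and 45-second deadline are illustrative stand-ins, not the repository's real code.

# Minimal sketch of the fan-out/harvest pattern used by run() above.
# Assumes only that `ray` is installed; FilterActor and the inputs
# below are illustrative stand-ins, not the repository's real code.
import time

import ray

ray.init(ignore_reinit_error=True)

@ray.remote
class FilterActor:
  def filter_context(self, context, user_query):
    # Stand-in for the real LLM call (run_anyscale in the diff).
    return {"completion": "yes", "context": context}

contexts = ["doc one", "doc two", "doc three"]
actor = FilterActor.options(max_concurrency=2).remote()
futures = [actor.filter_context.remote(c, "what is X?") for c in contexts]

results = []
deadline = time.time() + 45  # mirrors the max_time_before_return cutoff
while futures and time.time() < deadline:
  # ray.wait blocks until at least one future is ready, then returns
  # the ready refs and the still-pending remainder.
  ready, futures = ray.wait(futures, num_returns=1)
  results.append(ray.get(ready[0]))

print(len(results), "contexts processed before the deadline")

One difference worth noting: the sketch re-waits only on the not-ready remainder that ray.wait hands back, so each result is collected exactly once, whereas the diff's loop passes the full result_futures list on every iteration; the repo may handle removal in code elided from this diff.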
15 changes: 8 additions & 7 deletions ai_ta_backend/vector_database.py
@@ -43,6 +43,7 @@
 from ai_ta_backend.extreme_context_stuffing import OpenAIAPIProcessor
 from ai_ta_backend.utils_tokenization import count_tokens_and_cost
 from ai_ta_backend.parallel_context_processing import context_processing
+from ai_ta_backend.filtering_contexts import run


 MULTI_QUERY_PROMPT = hub.pull("langchain-ai/rag-fusion-query-generation")
@@ -1396,13 +1397,13 @@ def getTopContextsWithMQR(self, search_query: str, course_name: str, token_limit
         # filled our token size, time to return
         break

-    # for v in valid_docs:
-    #   print("FINAL VALID DOCS:")
-    #   #print("valid doc text: ", v['text'])
-    #   print("s3_path: ", v['s3_path'])
-    #   print("url: ", v['url'])
-    #   print("readable_filename: ", v['readable_filename'])
-    #   print("\n")
+    print("Length of valid docs: ", len(valid_docs))
+
+    # insert filtering here and only pass relevant contexts ahead
+    final_filtered_docs = list(run(contexts=valid_docs, user_query=search_query, max_time_before_return=45, max_concurrency=20))
+
+    print("Length of final filtered docs: ", len(final_filtered_docs))
+    #print("FINAL FILTERED DOCS: ", final_filtered_docs)

     print(f"Total tokens used: {token_counter} total docs: {len(found_docs)} num docs used: {len(valid_docs)}")
     print(f"Course: {course_name} ||| search_query: {search_query}")
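The net effect in getTopContextsWithMQR: after the token-limit packing loop fills valid_docs, those docs now pass through the LLM filter before being returned. A hedged sketch of that call site follows; apply_context_filter is a hypothetical wrapper name, and the fallback to unfiltered docs is an assumption for robustness, not something the diff itself does.

# Hedged sketch of the new call site; `run` is the import added above.
from ai_ta_backend.filtering_contexts import run

def apply_context_filter(valid_docs, search_query):
  final_filtered_docs = list(run(contexts=valid_docs,
                                 user_query=search_query,
                                 max_time_before_return=45,
                                 max_concurrency=20))
  # Assumption, not in the diff: if the filter times out or rejects
  # everything, fall back to the unfiltered set so callers still get contexts.
  return final_filtered_docs if final_filtered_docs else valid_docs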
4 changes: 4 additions & 0 deletions requirements.txt
@@ -52,3 +52,7 @@ unstructured==0.10.29 # causes huge ~5.3 GB of installs. Probbably from onnx: ht

 # Not currently supporting coursera ingest
 # cs-dlp @ git+https://github.com/raffaem/[email protected] # previously called coursera-dl
+xlrd # for excel ingest
+newrelic
+ray
+langchainhub