diff --git a/ai_ta_backend/filtering_contexts.py b/ai_ta_backend/filtering_contexts.py
index 6006b63e..98e71c71 100644
--- a/ai_ta_backend/filtering_contexts.py
+++ b/ai_ta_backend/filtering_contexts.py
@@ -13,6 +13,7 @@ from dotenv import load_dotenv
 from langchain import hub
 from langchain.prompts import PromptTemplate
+#from openai import OpenAI
 
 from ai_ta_backend.utils_tokenization import count_tokens_and_cost
 
 
@@ -33,10 +34,10 @@ def __init__(self):
 
   def filter_context(self, context, user_query, langsmith_prompt_obj):
     final_prompt = str(langsmith_prompt_obj.format(context=context, user_query=user_query))
-    # print(f"-------\nfinal_prompt:\n{final_prompt}\n^^^^^^^^^^^^^")
+    print(f"-------\nfinal_prompt:\n{final_prompt}\n^^^^^^^^^^^^^")
     try:
       # completion = run_model(final_prompt)
-      # completion = run_replicate(final_prompt)
+      #completion = run_replicate(final_prompt)
       completion = run_anyscale(final_prompt)
       return {"completion": completion, "context": context}
     except Exception as e:
@@ -81,11 +82,12 @@ def run_replicate(prompt):
   return output
 
 def run_anyscale(prompt):
+  print("in run anyscale")
   ret = openai.ChatCompletion.create(
     api_base = "https://api.endpoints.anyscale.com/v1",
     api_key=os.environ["ANYSCALE_ENDPOINT_TOKEN"],
     # model="meta-llama/Llama-2-70b-chat-hf",
-    model="mistralai/Mistral-7B-Instruct-v0.1",
+    engine="mistralai/Mistral-7B-Instruct-v0.1",
     messages=[{"role": "system", "content": "You are a helpful assistant."},
               {"role": "user", "content": prompt}],
     temperature=0.3,
@@ -94,6 +96,25 @@ def run_anyscale(prompt):
   print(ret["choices"][0]["message"]["content"])
   return ret["choices"][0]["message"]["content"]
 
+# def run_anyscale(prompt):
+#   print("in run anyscale")
+#   client = openai.OpenAI(
+#     base_url = "https://api.endpoints.anyscale.com/v1",
+#     api_key=os.environ["ANYSCALE_ENDPOINT_TOKEN"]
+#   )
+
+#   ret = client.chat.completions.create(
+#     # model="meta-llama/Llama-2-70b-chat-hf",
+#     model="mistralai/Mistral-7B-Instruct-v0.1",
+#     messages=[{"role": "system", "content": "You are a helpful assistant."},
+#               {"role": "user", "content": prompt}],
+#     temperature=0.3,
+#     max_tokens=250,
+#   )
+#   print("ANYSCALE RESPONSE: ", ret.choices[0].message.content)
+#   return ret.choices[0].message.content
+#   #return ret["choices"][0]["message"]["content"]
+
 def parse_result(result):
   lines = result.split('\n')
   for line in lines:
@@ -108,11 +129,15 @@ def run(contexts, user_query, max_tokens_to_return=3000, max_time_before_return=
 
   actor = AsyncActor.options(max_concurrency=max_concurrency).remote()
   result_futures = [actor.filter_context.remote(c, user_query, langsmith_prompt_obj) for c in contexts]
+  print("Num futures:", len(result_futures))
+  #print("Result futures:", result_futures)
+
   start_time = time.time()
 
   for i in range(0, len(result_futures)):
     try:
       ready, not_ready = ray.wait(result_futures)
+      print("ready:", ready)
       result = ray.get(ready[0])
 
       if result is None:
diff --git a/ai_ta_backend/vector_database.py b/ai_ta_backend/vector_database.py
index 78eee5f9..f0e88219 100644
--- a/ai_ta_backend/vector_database.py
+++ b/ai_ta_backend/vector_database.py
@@ -43,6 +43,7 @@ from ai_ta_backend.extreme_context_stuffing import OpenAIAPIProcessor
 from ai_ta_backend.utils_tokenization import count_tokens_and_cost
 from ai_ta_backend.parallel_context_processing import context_processing
+from ai_ta_backend.filtering_contexts import run
 
 MULTI_QUERY_PROMPT = hub.pull("langchain-ai/rag-fusion-query-generation")
 
 
@@ -1396,13 +1397,13 @@ def getTopContextsWithMQR(self, search_query: str, course_name: str, token_limit
         # filled our token size, time to return
         break
 
-    # for v in valid_docs:
-    #   print("FINAL VALID DOCS:")
-    #   #print("valid doc text: ", v['text'])
-    #   print("s3_path: ", v['s3_path'])
-    #   print("url: ", v['url'])
-    #   print("readable_filename: ", v['readable_filename'])
-    #   print("\n")
+    print("Length of valid docs: ", len(valid_docs))
+
+    # insert filtering here and only pass relevant contexts ahead
+    final_filtered_docs = list(run(contexts=valid_docs, user_query=search_query, max_time_before_return=45, max_concurrency=20))
+
+    print("Length of final filtered docs: ", len(final_filtered_docs))
+    #print("FINAL FILTERED DOCS: ", final_filtered_docs)
 
     print(f"Total tokens used: {token_counter} total docs: {len(found_docs)} num docs used: {len(valid_docs)}")
     print(f"Course: {course_name} ||| search_query: {search_query}")
diff --git a/requirements.txt b/requirements.txt
index e0f4823a..7f17ca4f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -52,3 +52,7 @@ unstructured==0.10.29 # causes huge ~5.3 GB of installs. Probbably from onnx: ht
 
 # Not currently supporting coursera ingest
 # cs-dlp @ git+https://github.com/raffaem/cs-dlp.git@0.12.0b0 # previously called coursera-dl
+xlrd # for excel ingest
+newrelic
+ray
+langchainhub
\ No newline at end of file
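
For reference (not part of the patch): a minimal sketch of how the new run() filter wired into getTopContextsWithMQR above can be exercised on its own. It assumes the context dicts carry a 'text' field like the valid_docs entries, that ANYSCALE_ENDPOINT_TOKEN is set, and that a Ray runtime is available; the sample docs and query are hypothetical.

    # Hypothetical standalone driver for the new context filter (not part of this diff).
    # Requires ANYSCALE_ENDPOINT_TOKEN in the environment and a running Ray instance.
    from ai_ta_backend.filtering_contexts import run

    sample_contexts = [
        {"text": "Ray actors let you run many filter calls concurrently.", "s3_path": "notes/ray.pdf"},
        {"text": "Bananas are rich in potassium.", "s3_path": "notes/fruit.pdf"},
    ]

    # run() fans each context out to the AsyncActor pool and returns only the
    # contexts the Mistral-based filter judges relevant; the caller materializes
    # the result with list(), as the vector_database.py change does.
    kept = list(run(contexts=sample_contexts,
                    user_query="How do I run filter calls in parallel?",
                    max_time_before_return=45,
                    max_concurrency=20))
    print("kept", len(kept), "of", len(sample_contexts), "contexts")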