Skip to content

Commit

Permalink
added a wrapper function for run()
Browse files Browse the repository at this point in the history
  • Loading branch information
star-nox committed Dec 8, 2023
1 parent 815dbee commit 1721bc2
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
9 changes: 8 additions & 1 deletion ai_ta_backend/filtering_contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,16 @@ def run(contexts, user_query, max_tokens_to_return=3000, max_time_before_return=
if not result_futures:
break


def ray_run(contexts, user_query, max_time_before_return=45, max_concurrency=100):
ray.init()
filtered_passages = list(run(contexts, user_query, max_time_before_return=max_time_before_return, max_concurrency=max_concurrency))
return filtered_passages


# ! CONDA ENV: llm-serving
if __name__ == "__main__":
ray.init()
#ray.init()
start_time = time.monotonic()
# print(len(CONTEXTS))

Expand Down
7 changes: 4 additions & 3 deletions ai_ta_backend/vector_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from ai_ta_backend.extreme_context_stuffing import OpenAIAPIProcessor
from ai_ta_backend.utils_tokenization import count_tokens_and_cost
from ai_ta_backend.parallel_context_processing import context_processing
from ai_ta_backend.filtering_contexts import run
from ai_ta_backend.filtering_contexts import run, ray_run


MULTI_QUERY_PROMPT = hub.pull("langchain-ai/rag-fusion-query-generation")
Expand Down Expand Up @@ -1164,8 +1164,9 @@ def getTopContextsWithMQR(self, search_query: str, course_name: str, token_limit
# count tokens at start and end, then also count each context.
token_counter, _ = count_tokens_and_cost(pre_prompt + '\n\nNow please respond to my query: ' + search_query) # type: ignore

ray.init()
filtered_docs = list(run(contexts=final_docs, user_query=search_query, max_time_before_return=45, max_concurrency=100))
filtered_docs = ray_run(contexts=final_docs, user_query=search_query, max_time_before_return=45, max_concurrency=100)

#filtered_docs = list(run(contexts=final_docs, user_query=search_query, max_time_before_return=45, max_concurrency=100))
print(f"Number of docs after context filtering: {len(filtered_docs)}")

valid_docs = []
Expand Down

0 comments on commit 1721bc2

Please sign in to comment.