anyscale LLM inference is faster than replicate or kastan.ai, 10 seconds for 80 inferences
KastanDay committed Nov 7, 2023
1 parent 9aae9e6 commit fd99ebf
Showing 1 changed file with 22 additions and 3 deletions.
ai_ta_backend/filtering_contexts.py: 25 changes (22 additions & 3 deletions)
@@ -1,11 +1,14 @@
-import replicate
 # Env for kastan:
 
+import inspect
 import json
 import os
 import time
 import traceback
 
+import openai
 import ray
+import replicate
 import requests
 from dotenv import load_dotenv
 from langchain import hub
@@ -33,7 +36,8 @@ def filter_context(self, context, user_query, langsmith_prompt_obj):
   # print(f"-------\nfinal_prompt:\n{final_prompt}\n^^^^^^^^^^^^^")
   try:
     # completion = run_model(final_prompt)
-    completion = run_replicate(final_prompt)
+    # completion = run_replicate(final_prompt)
+    completion = run_anyscale(final_prompt)
     return {"completion": completion, "context": context}
   except Exception as e:
     print(f"Error: {e}")
@@ -76,6 +80,20 @@ def run_replicate(prompt):
   print(output)
   return output
 
+def run_anyscale(prompt):
+  ret = openai.ChatCompletion.create(
+      api_base="https://api.endpoints.anyscale.com/v1",
+      api_key=os.environ["ANYSCALE_ENDPOINT_TOKEN"],
+      # model="meta-llama/Llama-2-70b-chat-hf",
+      model="mistralai/Mistral-7B-Instruct-v0.1",
+      messages=[{"role": "system", "content": "You are a helpful assistant."},
+                {"role": "user", "content": prompt}],
+      temperature=0.3,
+      max_tokens=250,
+  )
+  print(ret["choices"][0]["message"]["content"])
+  return ret["choices"][0]["message"]["content"]
+
 def parse_result(result):
   lines = result.split('\n')
   for line in lines:
@@ -131,7 +149,8 @@ def run(contexts, user_query, max_tokens_to_return=3000, max_time_before_return=
 if __name__ == "__main__":
   ray.init()
   start_time = time.monotonic()
-  final_passage_list = list(run(contexts=CONTEXTS[0:12], user_query=USER_QUERY, max_time_before_return=20, max_concurrency=22))
+  # print(len(CONTEXTS))
+  final_passage_list = list(run(contexts=CONTEXTS*2, user_query=USER_QUERY, max_time_before_return=45, max_concurrency=20))
 
   print("✅✅✅ FINAL RESULTS: \n" + '\n'.join(json.dumps(r, indent=2) for r in final_passage_list))
   print("✅✅✅ TOTAL RETURNED: ", len(final_passage_list))
