Skip to content

Commit

Permalink
Added OpenAI API type to all OpenAI functions
Browse files — browse the repository at this point in the history
  • Loading branch information
star-nox committed Dec 6, 2023
1 parent 3e5a553 commit 256002e
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 7 deletions.
5 changes: 4 additions & 1 deletion ai_ta_backend/filtering_contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def run_replicate(prompt):

def run_anyscale(prompt):
print("in run anyscale")

ret = openai.ChatCompletion.create(
api_base = "https://api.endpoints.anyscale.com/v1",
api_key=os.environ["ANYSCALE_ENDPOINT_TOKEN"],
Expand Down Expand Up @@ -126,9 +127,11 @@ def run(contexts, user_query, max_tokens_to_return=3000, max_time_before_return=
langsmith_prompt_obj = hub.pull("kastanday/filter-unrelated-contexts-zephyr")

print("Num jobs to run:", len(contexts))
#print("Context: ", contexts[0])
#exit()

actor = AsyncActor.options(max_concurrency=max_concurrency).remote()
result_futures = [actor.filter_context.remote(c, user_query, langsmith_prompt_obj) for c in contexts]
result_futures = [actor.filter_context.remote(c['text'], user_query, langsmith_prompt_obj) for c in contexts]
print("Num futures:", len(result_futures))
#print("Result futures:", result_futures)

Expand Down
4 changes: 2 additions & 2 deletions ai_ta_backend/nomic_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def log_convo_to_nomic(course_name: str, conversation) -> str:
}]

# create embeddings
embeddings_model = OpenAIEmbeddings() # type: ignore
embeddings_model = OpenAIEmbeddings(openai_api_type="azure") # type: ignore
embeddings = embeddings_model.embed_documents(user_queries)

# add embeddings to the project
Expand Down Expand Up @@ -279,7 +279,7 @@ def create_nomic_map(course_name: str, log_data: list):
metadata.append(metadata_row)

metadata = pd.DataFrame(metadata)
embeddings_model = OpenAIEmbeddings() # type: ignore
embeddings_model = OpenAIEmbeddings(openai_api_type="azure") # type: ignore
embeddings = embeddings_model.embed_documents(user_queries)

# create Atlas project
Expand Down
11 changes: 7 additions & 4 deletions ai_ta_backend/vector_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@


MULTI_QUERY_PROMPT = hub.pull("langchain-ai/rag-fusion-query-generation")
OPENAI_API_TYPE = "azure" # "openai" or "azure"

class Ingest():
"""
Expand All @@ -68,7 +69,7 @@ def __init__(self):
self.vectorstore = Qdrant(
client=self.qdrant_client,
collection_name=os.environ['QDRANT_COLLECTION_NAME'],
embeddings=OpenAIEmbeddings()) # type: ignore
embeddings=OpenAIEmbeddings(openai_api_type=OPENAI_API_TYPE)) # type: ignore

# S3
self.s3_client = boto3.client(
Expand All @@ -89,6 +90,7 @@ def __init__(self):
openai_api_key=os.getenv('AZURE_OPENAI_KEY'), #type:ignore
#openai_api_version=os.getenv('AZURE_OPENAI_API_VERSION'), #type:ignore
openai_api_version=os.getenv('OPENAI_API_VERSION'), #type:ignore
openai_api_type=OPENAI_API_TYPE
)
# self.llm = OpenAI(temperature=0, openai_api_base='https://api.kastan.ai/v1')
#self.llm = ChatOpenAI(temperature=0, model='gpt-3.5-turbo')
Expand Down Expand Up @@ -281,6 +283,7 @@ def _ingest_single_video(self, s3_path: str, course_name: str, **kwargs) -> str:
# check for file extension
file_ext = Path(s3_path).suffix
openai.api_key = os.getenv('OPENAI_API_KEY')

transcript_list = []
with NamedTemporaryFile(suffix=file_ext) as video_tmpfile:
# download from S3 into an video tmpfile
Expand Down Expand Up @@ -953,7 +956,7 @@ def getAll(

def vector_search(self, search_query, course_name):
top_n = 80
o = OpenAIEmbeddings() # type: ignore
o = OpenAIEmbeddings(openai_api_type=OPENAI_API_TYPE) # type: ignore
user_query_embedding = o.embed_query(search_query)
myfilter = models.Filter(
must=[
Expand Down Expand Up @@ -996,7 +999,7 @@ def vector_search(self, search_query, course_name):

def batch_vector_search(self, search_queries: List[str], course_name: str, top_n: int=20):
from qdrant_client.http import models as rest
o = OpenAIEmbeddings()
o = OpenAIEmbeddings(openai_api_type=OPENAI_API_TYPE)
# Prepare the filter for the course name
myfilter = rest.Filter(
must=[
Expand Down Expand Up @@ -1515,7 +1518,7 @@ def get_stuffed_prompt(self, search_query: str, course_name: str, token_limit: i
try:
top_n = 150
start_time_overall = time.monotonic()
o = OpenAIEmbeddings() # type: ignore
o = OpenAIEmbeddings(openai_api_type=OPENAI_API_TYPE) # type: ignore
user_query_embedding = o.embed_documents(search_query)[0] # type: ignore
myfilter = models.Filter(
must=[
Expand Down

0 comments on commit 256002e

Please sign in to comment.