Skip to content

Commit

Permalink
Update OpenAI API type to be fetched from environment variable
Browse files Browse the repository at this point in the history
  • Loading branch information
rohan-uiuc committed Mar 8, 2024
1 parent 81fc4ef commit 5951fe4
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 7 deletions.
4 changes: 2 additions & 2 deletions ai_ta_backend/service/nomic_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def log_convo_to_nomic(self, course_name: str, conversation) -> str | None:
}]

# create embeddings
embeddings_model = OpenAIEmbeddings(openai_api_type=OPENAI_API_TYPE) # type: ignore
embeddings_model = OpenAIEmbeddings(openai_api_type=os.environ['OPENAI_API_TYPE'])
embeddings = embeddings_model.embed_documents(user_queries)

# add embeddings to the project - create a new function for this
Expand Down Expand Up @@ -380,7 +380,7 @@ def create_nomic_map(self, course_name: str, log_data: list):
metadata.append(metadata_row)

metadata = pd.DataFrame(metadata)
embeddings_model = OpenAIEmbeddings(openai_api_type=OPENAI_API_TYPE) # type: ignore
embeddings_model = OpenAIEmbeddings(openai_api_type=os.environ['OPENAI_API_TYPE'])
embeddings = embeddings_model.embed_documents(user_queries)

# create Atlas project
Expand Down
8 changes: 3 additions & 5 deletions ai_ta_backend/service/retrieval_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
from ai_ta_backend.service.sentry_service import SentryService
from ai_ta_backend.utils.utils_tokenization import count_tokens_and_cost

OPENAI_API_TYPE = "azure" # "openai" or "azure"


class RetrievalService:
"""
Expand All @@ -41,7 +39,7 @@ def __init__(self, vdb: VectorDatabase, sqlDb: SQLDatabase, aws: AWSStorage, pos
self.embeddings = OpenAIEmbeddings(
model='text-embedding-ada-002',
openai_api_base=os.getenv("AZURE_OPENAI_ENDPOINT"), # type:ignore
openai_api_type=OPENAI_API_TYPE,
openai_api_type=os.environ['OPENAI_API_TYPE'],
openai_api_key=os.getenv("AZURE_OPENAI_KEY"), # type:ignore
openai_api_version=os.getenv("OPENAI_API_VERSION"), # type:ignore
)
Expand All @@ -52,7 +50,7 @@ def __init__(self, vdb: VectorDatabase, sqlDb: SQLDatabase, aws: AWSStorage, pos
openai_api_base=os.getenv("AZURE_OPENAI_ENDPOINT"), # type:ignore
openai_api_key=os.getenv("AZURE_OPENAI_KEY"), # type:ignore
openai_api_version=os.getenv("OPENAI_API_VERSION"), # type:ignore
openai_api_type=OPENAI_API_TYPE,
openai_api_type=os.environ['OPENAI_API_TYPE'],
)

def getTopContexts(self, search_query: str, course_name: str, token_limit: int = 4_000) -> Union[List[Dict], str]:
Expand Down Expand Up @@ -336,7 +334,7 @@ def vector_search(self, search_query, course_name):
top_n = 80
# EMBED
openai_start_time = time.monotonic()
print("OPENAI_API_TYPE", OPENAI_API_TYPE)
print("OPENAI_API_TYPE", os.environ['OPENAI_API_TYPE'])
user_query_embedding = self.embeddings.embed_query(search_query)
openai_embedding_latency = time.monotonic() - openai_start_time

Expand Down

0 comments on commit 5951fe4

Please sign in to comment.