Insert bulk doc groups #264

Open
wants to merge 5 commits into main

Changes from all commits
23 changes: 23 additions & 0 deletions ai_ta_backend/database/sql.py
@@ -1,4 +1,5 @@
import os
from typing import Dict, List

import supabase
from injector import inject
@@ -123,3 +124,25 @@ def getConversation(self, course_name: str, key: str, value: str):
return self.supabase_client.table("llm-convo-monitor").select("*").eq(key, value).eq("course_name", course_name).execute()


  def fetchDocumentsByURLs(self, urls: List[str], course_name: str, page: int = 1, items_per_page: int = 1500):
    """
    Fetch documents whose base_url matches any URL in the provided list, scoped to the course and paginated.
    """
    return (self.supabase_client.table("documents")
            .select("id, readable_filename, url, s3_path, base_url, doc_groups(name)")
            .in_("base_url", urls)
            .eq("course_name", course_name)
            .range((page - 1) * items_per_page, page * items_per_page - 1)
            .execute())

  def insertDocumentGroupsBulk(self, document_group: Dict):
    # Upsert the document group; on a (name, course_name) conflict the existing row is updated,
    # and the affected row is returned so its id is available either way.
    inserted_records = self.supabase_client.table("doc_groups").upsert(document_group, on_conflict="name, course_name", ignore_duplicates=False).execute()
    print(f"Upserted document group: {inserted_records.data}")
    # One group is upserted per call, so return the id of that single record.
    inserted_id = inserted_records.data[0]['id']
    return inserted_id

  def updateDocumentsDocGroupsBulk(self, document_ids: List[int], doc_group_id: int):
    # Prepare one mapping row per document for the documents_doc_groups junction table.
    updates = [{"document_id": doc_id, "doc_group_id": doc_group_id} for doc_id in document_ids]
    # Bulk upsert; existing (document_id, doc_group_id) pairs are skipped.
    self.supabase_client.table("documents_doc_groups").upsert(updates, on_conflict="document_id, doc_group_id", ignore_duplicates=True).execute()



36 changes: 36 additions & 0 deletions ai_ta_backend/database/vector.py
@@ -5,6 +5,7 @@
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Qdrant
from qdrant_client import QdrantClient, models
from qdrant_client.conversions.common_types import WriteOrdering

OPENAI_API_TYPE = "azure" # "openai" or "azure"

@@ -86,3 +87,38 @@ def delete_data(self, collection_name: str, key: str, value: str):
),
]),
)

  def add_document_groups_to_documents(self, course_name: str, documents: List[dict], doc_group_name: str):
    """
    Add a document group to existing documents in the vector database by rewriting each point's `doc_groups` payload.
    """
    print(f"Adding doc group '{doc_group_name}' to {len(documents)} documents in the vector database for course: {course_name}.")
    update_operations = []
    for document in documents:
      # print(f"Adding document groups to document: {document}")
      # Match each Qdrant point by url when available, otherwise by s3_path.
      key = "url" if "url" in document else "s3_path"
      value = models.MatchValue(value=document[key])
      search_filter = models.Filter(must=[
          models.FieldCondition(key="course_name", match=models.MatchValue(value=course_name)),
          models.FieldCondition(key=key, match=value),
      ])

      # The new payload is the document's existing groups (from SQL) plus the group being added.
      payload = {
          "doc_groups": [group["name"] for group in document["doc_groups"]] + [doc_group_name],
      }

      # print(f"Updating to Payload: {payload}")

      update_operations.append(models.SetPayloadOperation(set_payload=models.SetPayload(payload=payload, filter=search_filter)))

    print(f"update_operations for qdrant: {len(update_operations)}")
    result = self.qdrant_client.batch_update_points(
        collection_name=os.environ['QDRANT_COLLECTION_NAME'],
        update_operations=update_operations,
        wait=False,
    )
    return result
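
Because `wait=False` is passed, `batch_update_points` returns before the payload writes are necessarily applied. A quick way to spot-check the outcome afterwards is to scroll one matching point and inspect its `doc_groups` payload. A minimal sketch, assuming the same client and collection environment variable; the helper name `check_doc_groups` is illustrative and not part of this diff:

import os

from qdrant_client import QdrantClient, models

def check_doc_groups(client: QdrantClient, course_name: str, url: str):
  # Fetch a single point matching the course and url, payload only, no vectors.
  points, _next_offset = client.scroll(
      collection_name=os.environ['QDRANT_COLLECTION_NAME'],
      scroll_filter=models.Filter(must=[
          models.FieldCondition(key="course_name", match=models.MatchValue(value=course_name)),
          models.FieldCondition(key="url", match=models.MatchValue(value=url)),
      ]),
      limit=1,
      with_payload=["doc_groups"],
      with_vectors=False,
  )
  return points[0].payload.get("doc_groups") if points else None
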
11 changes: 10 additions & 1 deletion ai_ta_backend/main.py
@@ -380,6 +380,16 @@ def getTopContextsWithMQR(service: RetrievalService, posthog_service: PosthogSer
response.headers.add('Access-Control-Allow-Origin', '*')
return response


@app.route('/insert_document_groups', methods=['POST'])
def insert_document_groups(service: RetrievalService) -> Response:
  """
  Bulk-insert document groups from a CSV mapping and attach the matching documents in SQL and the vector DB.
  """
  data = request.get_json()
  csv_path: str = data.get('csv_path', '')
  course_name: str = data.get('course_name', '')
  doc_group_count, docs_doc_group_count_sql, docs_doc_group_count_vdb = service.insertDocumentGroups(course_name, csv_path)

  return jsonify({
      "message": "Document groups and documents inserted successfully.",
      "doc_group_count": doc_group_count,
      "docs_doc_group_count_sql": docs_doc_group_count_sql,
      "docs_doc_group_count_vdb": docs_doc_group_count_vdb,
  })
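
For reference, a minimal client-side sketch of calling this route with the requests library; the host, port, file path, and course name are placeholders, and csv_path must be readable by the backend process:

import requests

# Placeholder URL and payload values.
resp = requests.post(
    "http://localhost:8000/insert_document_groups",
    json={"csv_path": "doc_groups_mapping.csv", "course_name": "example-course"},
)
print(resp.json())  # {"message": ..., "doc_group_count": ..., "docs_doc_group_count_sql": ..., "docs_doc_group_count_vdb": ...}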

@app.route('/getworkflows', methods=['GET'])
def get_all_workflows(service: WorkflowService) -> Response:
"""
@@ -471,7 +481,6 @@ def run_flow(service: WorkflowService) -> Response:
else:
abort(400, description=f"Bad request: {e}")


def configure(binder: Binder) -> None:
binder.bind(RetrievalService, to=RetrievalService, scope=RequestScope)
binder.bind(PosthogService, to=PosthogService, scope=SingletonScope)
87 changes: 87 additions & 0 deletions ai_ta_backend/service/retrieval_service.py
@@ -5,6 +5,7 @@
from typing import Dict, List, Union

import openai
import csv
from injector import inject
from langchain.chat_models import AzureChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings
@@ -466,3 +467,89 @@ def format_for_json(self, found_docs: List[Document]) -> List[Dict]:
]

return contexts

  def insertDocumentGroups(self, course_name, csv_path):
    """
    Inserts document groups and links the matching documents, based on the CSV mapping.
    Also writes the doc group name along with each updated document's info to the output CSV files.
    """
    doc_group_count, docs_doc_groups_count_sql, docs_doc_groups_count_vdb = 0, 0, 0
    try:
      with open(csv_path, newline='') as csvfile:
        for row in csv.DictReader(csvfile):
          # 'start_urls' holds a Python-list literal of base URLs; ast.literal_eval would be a safer parser here.
          urls = eval(row['start_urls'])
          doc_group_name = row['university']
          document_group = {'name': doc_group_name, 'course_name': course_name}
          document_group["id"] = self._ingest_document_group(document_group)
          if document_group["id"]:
            doc_group_count += 1
            sql_count, vdb_count = self._process_documents_in_group(urls, course_name, document_group)
            docs_doc_groups_count_sql += sql_count
            docs_doc_groups_count_vdb += vdb_count
    except FileNotFoundError as e:
      print(f"CSV file not found: {str(e)}")
    except Exception as e:
      print(f"An error occurred while processing the CSV file: {str(e)}")

    return doc_group_count, docs_doc_groups_count_sql, docs_doc_groups_count_vdb
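
For reference, a small sketch of the input CSV this method expects, based on the column names read above ('university' for the group name, 'start_urls' holding a Python-list literal of base URLs); the file name and values are illustrative only:

import csv

# Illustrative input for insertDocumentGroups; column names match what the code reads.
with open('doc_groups_mapping.csv', 'w', newline='') as f:
  writer = csv.DictWriter(f, fieldnames=['university', 'start_urls'])
  writer.writeheader()
  writer.writerow({
      'university': 'Example University',
      'start_urls': "['https://example.edu/', 'https://catalog.example.edu/']",
  })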

  def _ingest_document_group(self, document_group):
    print(f"Inserting document group: {document_group['name']}")
    try:
      doc_group_id = self.sqlDb.insertDocumentGroupsBulk(document_group)
      return doc_group_id
    except Exception as e:
      print(f"Failed to insert document group {document_group['name']} due to: {str(e)}")
      return None

  def _process_documents_in_group(self, urls, course_name, document_group):
    page, docs_doc_group_count_sql, docs_doc_group_count_vdb = 1, 0, 0
    while True:
      print(f"fetching page: {page} for doc_group: {document_group['name']}, with urls: {urls}")
      existing_documents, updated_points = self._fetch_and_update_documents(urls, course_name, page, document_group)
      if not existing_documents:
        break
      docs_doc_group_count_sql += len(existing_documents)
      docs_doc_group_count_vdb += len(updated_points if updated_points else [])
      page += 1
    return docs_doc_group_count_sql, docs_doc_group_count_vdb

  def _fetch_and_update_documents(self, urls, course_name, page, document_group):
    try:
      existing_documents_response = self.sqlDb.fetchDocumentsByURLs(urls, course_name, page)
      existing_documents = existing_documents_response.data
      print(f"Existing documents for page {page} and doc_group {document_group['name']}: {len(existing_documents)}")
      if not existing_documents:
        return None, None
      documents_to_update = [doc["id"] for doc in existing_documents]

      print(f"Updating documents for page {page} and doc_group {document_group['name']} in SQL")
      self.sqlDb.updateDocumentsDocGroupsBulk(documents_to_update, document_group["id"])

      print(f"Updating documents for page {page} and doc_group {document_group['name']} in VDB")
      updated_points = self.vdb.add_document_groups_to_documents(course_name, existing_documents, document_group['name'])

      print(f"Writing updated documents to CSV for doc_group {document_group['name']} and page {page}")

      self._write_to_csv('result/updated_documents_sql.csv', existing_documents, document_group)
      self._write_to_csv('result/updated_documents_vdb.csv', updated_points, document_group, is_vdb=True)
      return existing_documents, updated_points
    except Exception as e:
      print(f"Failed to fetch/update documents for page {page} and doc_group {document_group['name']} due to: {str(e)}")
      return None, None

  def _write_to_csv(self, file_path, data, document_group, is_vdb=False):
    try:
      print(f"Writing updated documents to CSV for {document_group['name']}, is_vdb: {is_vdb}")
      # Append mode: the header row is re-written on every call (once per page per group).
      with open(file_path, newline='', mode='a') as output_csvfile:
        writer = csv.writer(output_csvfile)
        if is_vdb:
          writer.writerow(['doc_group_name', 'operation results'])
          for doc in data:
            writer.writerow([document_group['name']] + list(doc))
        else:
          writer.writerow(['doc_group_name', 'id', 'readable_filename', 'url', 's3_path', 'base_url'])
          for doc in data:
            writer.writerow([document_group['name']] + list(doc.values()))
    except Exception as e:
      print(f"Failed to write updated documents to CSV for {document_group['name']} due to: {str(e)}")