diff --git a/ai_ta_backend/database/sql.py b/ai_ta_backend/database/sql.py
index 6f7ae01d..d8fece24 100644
--- a/ai_ta_backend/database/sql.py
+++ b/ai_ta_backend/database/sql.py
@@ -1,4 +1,5 @@
 import os
+from typing import Dict, List

 import supabase
 from injector import inject
@@ -123,3 +124,25 @@ def getConversation(self, course_name: str, key: str, value: str):
     return self.supabase_client.table("llm-convo-monitor").select("*").eq(key, value).eq("course_name", course_name).execute()


+  def fetchDocumentsByURLs(self, urls: List[str], course_name: str, page: int = 1, items_per_page: int = 1500):
+    """
+    Fetch documents whose base_url matches any of the URLs in the provided list.
+    """
+    return self.supabase_client.table("documents").select("id, readable_filename, url, s3_path, base_url, doc_groups(name)").in_("base_url", urls).eq("course_name", course_name).range((page - 1) * items_per_page, page * items_per_page - 1).execute()
+
+  def insertDocumentGroupsBulk(self, document_group):
+    # Assuming the Supabase client's upsert returns the inserted or updated record
+    inserted_records = self.supabase_client.table("doc_groups").upsert(document_group, on_conflict="name, course_name", ignore_duplicates=False).execute()
+    print(f"Inserted document groups: {inserted_records.data}")
+    # Extract and return the ID of the upserted document group
+    inserted_id = inserted_records.data[0]['id']
+    return inserted_id
+
+  def updateDocumentsDocGroupsBulk(self, document_ids: List[int], doc_group_id: int):
+    # Prepare the rows for the documents_doc_groups join table
+    updates = [{"document_id": doc_id, "doc_group_id": doc_group_id} for doc_id in document_ids]
+    # Perform the bulk upsert
+    self.supabase_client.table("documents_doc_groups").upsert(updates, on_conflict="document_id, doc_group_id", ignore_duplicates=True).execute()
+
+
+
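
Note: a minimal usage sketch of the new SQL helpers, assuming `sql_db` is the injected instance of the database class in sql.py; the group name, course name, and URL below are illustrative placeholders, not values from this change.

    # Hypothetical wiring: sql_db is the injected SQL database wrapper from sql.py.
    group_id = sql_db.insertDocumentGroupsBulk({"name": "Example University", "course_name": "example-course"})
    response = sql_db.fetchDocumentsByURLs(["https://example.edu/catalog"], "example-course", page=1)
    doc_ids = [doc["id"] for doc in response.data]
    sql_db.updateDocumentsDocGroupsBulk(doc_ids, group_id)

With `ignore_duplicates=False`, the `doc_groups` upsert updates the existing row on a `(name, course_name)` conflict, so `insertDocumentGroupsBulk` still returns an id when the group already exists.
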
+ """ + print(f"Adding document groups: {doc_group_name} to documents in the vector database for course: {course_name} and {len(documents)} documents.") + update_operations = [] + for document in documents: + # print(f"Adding document groups to document: {document} ") + key = "url" if "url" in document else "s3_path" + value = models.MatchValue(value=document[key]) + searchFilter = models.Filter( + must=[ + models.FieldCondition(key="course_name", match=models.MatchValue(value=course_name)), + models.FieldCondition(key=key, + match=value)]) + + payload = { + "doc_groups": [group["name"] for group in document["doc_groups"]] + [doc_group_name], + } + + # print(f"Updating to Payload: {payload}") + + update_operations.append(models.SetPayloadOperation( + set_payload=models.SetPayload( + payload=payload, + filter=searchFilter + ), + )) + + print(f"update_operations for qdrant: {len(update_operations)}") + result = self.qdrant_client.batch_update_points( + collection_name=os.environ['QDRANT_COLLECTION_NAME'], + update_operations=update_operations, wait=False) + return result \ No newline at end of file diff --git a/ai_ta_backend/main.py b/ai_ta_backend/main.py index 0085abd5..ae0430ff 100644 --- a/ai_ta_backend/main.py +++ b/ai_ta_backend/main.py @@ -380,6 +380,16 @@ def getTopContextsWithMQR(service: RetrievalService, posthog_service: PosthogSer response.headers.add('Access-Control-Allow-Origin', '*') return response + +@app.route('/insert_document_groups', methods=['POST']) +def insert_document_groups(service: RetrievalService) -> Response: + data = request.get_json() + csv_path: str = data.get('csv_path', '') + course_name: str = data.get('course_name', '') + doc_group_count, docs_doc_group_count_sql, docs_doc_group_count_vdb = service.insertDocumentGroups(course_name, csv_path) + + return jsonify({"message": "Document groups and documents inserted successfully.", "doc_group_count": doc_group_count, "docs_doc_group_count_sql": docs_doc_group_count_sql, "docs_doc_group_count_vdb": docs_doc_group_count_vdb}) + @app.route('/getworkflows', methods=['GET']) def get_all_workflows(service: WorkflowService) -> Response: """ @@ -471,7 +481,6 @@ def run_flow(service: WorkflowService) -> Response: else: abort(400, description=f"Bad request: {e}") - def configure(binder: Binder) -> None: binder.bind(RetrievalService, to=RetrievalService, scope=RequestScope) binder.bind(PosthogService, to=PosthogService, scope=SingletonScope) diff --git a/ai_ta_backend/service/retrieval_service.py b/ai_ta_backend/service/retrieval_service.py index c53bcefb..2eb81c56 100644 --- a/ai_ta_backend/service/retrieval_service.py +++ b/ai_ta_backend/service/retrieval_service.py @@ -5,6 +5,7 @@ from typing import Dict, List, Union import openai +import csv from injector import inject from langchain.chat_models import AzureChatOpenAI from langchain.embeddings.openai import OpenAIEmbeddings @@ -466,3 +467,89 @@ def format_for_json(self, found_docs: List[Document]) -> List[Dict]: ] return contexts + + def insertDocumentGroups(self, course_name, csv_path): + """ + Inserts document groups and documents into the database based on the CSV mapping. + Additionally, saves doc group name along with the entire document info in the output csv file. 
+ """ + doc_group_count, docs_doc_groups_count_sql, docs_doc_groups_count_vdb = 0, 0, 0 + try: + with open(csv_path, newline='') as csvfile: + for row in csv.DictReader(csvfile): + urls = eval(row['start_urls']) + doc_group_name = row['university'] + document_group = {'name': doc_group_name, 'course_name': course_name} + document_group["id"] = self._ingest_document_group(document_group) + if document_group["id"]: + doc_group_count += 1 + sql_count, vdb_count = self._process_documents_in_group(urls, course_name, document_group) + docs_doc_groups_count_sql += sql_count + docs_doc_groups_count_vdb += vdb_count + except FileNotFoundError as e: + print(f"CSV file not found: {str(e)}") + except Exception as e: + print(f"An error occurred while processing the CSV file: {str(e)}") + + return doc_group_count, docs_doc_groups_count_sql, docs_doc_groups_count_vdb + + def _ingest_document_group(self, document_group): + print(f"Inserting document group: {document_group['name']}") + try: + doc_group_id = self.sqlDb.insertDocumentGroupsBulk(document_group) + return doc_group_id + except Exception as e: + print(f"Failed to insert document group {document_group['name']} due to: {str(e)}") + return None + + def _process_documents_in_group(self, urls, course_name, document_group): + page, docs_doc_group_count_sql, docs_doc_group_count_vdb = 1, 0, 0 + while True: + print(f"fetching page: {page} for doc_group: {document_group['name']}, with urls: {urls}") + existing_documents, updated_points = self._fetch_and_update_documents(urls, course_name, page, document_group) + if not existing_documents: + break + docs_doc_group_count_sql += len(existing_documents) + docs_doc_group_count_vdb += len(updated_points if updated_points else []) + page += 1 + return docs_doc_group_count_sql, docs_doc_group_count_vdb + + def _fetch_and_update_documents(self, urls, course_name, page, document_group): + try: + existing_documents_response = self.sqlDb.fetchDocumentsByURLs(urls, course_name, page) + existing_documents = existing_documents_response.data + print(f"Existing documents for page {page} and doc_group {document_group['name']}: {len(existing_documents)}") + if not existing_documents: + return None, None + documents_to_update = [doc["id"] for doc in existing_documents] + + print(f"Updating documents for page {page} and doc_group {document_group['name']} in SQL") + self.sqlDb.updateDocumentsDocGroupsBulk(documents_to_update, document_group["id"]) + + print(f"Updating documents for page {page} and doc_group {document_group['name']} in VDB") + updated_points = self.vdb.add_document_groups_to_documents(course_name, existing_documents, document_group['name']) + + print(f"Writing updated documents to CSV for doc_group {document_group['name']} and page {page}") + + self._write_to_csv('result/updated_documents_sql.csv', existing_documents, document_group) + self._write_to_csv('result/updated_documents_vdb.csv', updated_points, document_group, is_vdb=True) + return existing_documents, updated_points + except Exception as e: + print(f"Failed to fetch/update documents for page {page} and doc_group {document_group['name']} due to: {str(e)}") + return None, None + + def _write_to_csv(self, file_path, data, document_group, is_vdb=False): + try: + print(f"Writing updated documents to CSV for {document_group['name']}, is_vdb: {is_vdb}") + with open(file_path, newline='', mode='a') as output_csvfile: + writer = csv.writer(output_csvfile) + if is_vdb: + writer.writerow(['doc_group_name', 'operation results']) + for doc in data: + 
diff --git a/ai_ta_backend/service/retrieval_service.py b/ai_ta_backend/service/retrieval_service.py
index c53bcefb..2eb81c56 100644
--- a/ai_ta_backend/service/retrieval_service.py
+++ b/ai_ta_backend/service/retrieval_service.py
@@ -5,6 +5,8 @@
 from typing import Dict, List, Union

 import openai
+import ast
+import csv
 from injector import inject
 from langchain.chat_models import AzureChatOpenAI
 from langchain.embeddings.openai import OpenAIEmbeddings
@@ -466,3 +468,89 @@ def format_for_json(self, found_docs: List[Document]) -> List[Dict]:
     ]

     return contexts
+
+  def insertDocumentGroups(self, course_name, csv_path):
+    """
+    Inserts document groups and documents into the database based on the CSV mapping.
+    Additionally, saves the doc group name along with the full document info in the output CSV files.
+    """
+    doc_group_count, docs_doc_groups_count_sql, docs_doc_groups_count_vdb = 0, 0, 0
+    try:
+      with open(csv_path, newline='') as csvfile:
+        for row in csv.DictReader(csvfile):
+          urls = ast.literal_eval(row['start_urls'])  # parse the list literal safely instead of eval()
+          doc_group_name = row['university']
+          document_group = {'name': doc_group_name, 'course_name': course_name}
+          document_group["id"] = self._ingest_document_group(document_group)
+          if document_group["id"]:
+            doc_group_count += 1
+            sql_count, vdb_count = self._process_documents_in_group(urls, course_name, document_group)
+            docs_doc_groups_count_sql += sql_count
+            docs_doc_groups_count_vdb += vdb_count
+    except FileNotFoundError as e:
+      print(f"CSV file not found: {str(e)}")
+    except Exception as e:
+      print(f"An error occurred while processing the CSV file: {str(e)}")
+
+    return doc_group_count, docs_doc_groups_count_sql, docs_doc_groups_count_vdb
+
+  def _ingest_document_group(self, document_group):
+    print(f"Inserting document group: {document_group['name']}")
+    try:
+      doc_group_id = self.sqlDb.insertDocumentGroupsBulk(document_group)
+      return doc_group_id
+    except Exception as e:
+      print(f"Failed to insert document group {document_group['name']} due to: {str(e)}")
+      return None
+
+  def _process_documents_in_group(self, urls, course_name, document_group):
+    page, docs_doc_group_count_sql, docs_doc_group_count_vdb = 1, 0, 0
+    while True:
+      print(f"Fetching page: {page} for doc_group: {document_group['name']}, with urls: {urls}")
+      existing_documents, updated_points = self._fetch_and_update_documents(urls, course_name, page, document_group)
+      if not existing_documents:
+        break
+      docs_doc_group_count_sql += len(existing_documents)
+      docs_doc_group_count_vdb += len(updated_points) if updated_points else 0
+      page += 1
+    return docs_doc_group_count_sql, docs_doc_group_count_vdb
+
+  def _fetch_and_update_documents(self, urls, course_name, page, document_group):
+    try:
+      existing_documents_response = self.sqlDb.fetchDocumentsByURLs(urls, course_name, page)
+      existing_documents = existing_documents_response.data
+      print(f"Existing documents for page {page} and doc_group {document_group['name']}: {len(existing_documents)}")
+      if not existing_documents:
+        return None, None
+      documents_to_update = [doc["id"] for doc in existing_documents]
+
+      print(f"Updating documents for page {page} and doc_group {document_group['name']} in SQL")
+      self.sqlDb.updateDocumentsDocGroupsBulk(documents_to_update, document_group["id"])
+
+      print(f"Updating documents for page {page} and doc_group {document_group['name']} in VDB")
+      updated_points = self.vdb.add_document_groups_to_documents(course_name, existing_documents, document_group['name'])
+
+      print(f"Writing updated documents to CSV for doc_group {document_group['name']} and page {page}")
+
+      self._write_to_csv('result/updated_documents_sql.csv', existing_documents, document_group)
+      self._write_to_csv('result/updated_documents_vdb.csv', updated_points, document_group, is_vdb=True)
+      return existing_documents, updated_points
+    except Exception as e:
+      print(f"Failed to fetch/update documents for page {page} and doc_group {document_group['name']} due to: {str(e)}")
+      return None, None
+
+  def _write_to_csv(self, file_path, data, document_group, is_vdb=False):
+    try:
+      print(f"Writing updated documents to CSV for {document_group['name']}, is_vdb: {is_vdb}")
+      with open(file_path, newline='', mode='a') as output_csvfile:
+        writer = csv.writer(output_csvfile)
+        if is_vdb:
+          writer.writerow(['doc_group_name', 'operation results'])
+          for doc in data:
+            writer.writerow([document_group['name']] + list(doc))
+        else:
+          writer.writerow(['doc_group_name', 'id', 'readable_filename', 'url', 's3_path', 'base_url'])
+          for doc in data:
+            writer.writerow([document_group['name']] + list(doc.values()))
+    except Exception as e:
+      print(f"Failed to write updated documents to CSV for {document_group['name']} due to: {str(e)}")
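
Note: `insertDocumentGroups` reads the mapping CSV with `csv.DictReader` and uses only the `start_urls` and `university` columns, where `start_urls` holds a Python list literal of base URLs. A minimal example of the expected shape; the university name and URLs are illustrative.

    university,start_urls
    Example University,"['https://example.edu/catalog', 'https://example.edu/courses']"

Each row becomes one document group named after the `university` value, and every existing document in the course whose `base_url` appears in `start_urls` is attached to that group in both the SQL database and Qdrant.
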