Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate all Qdrant to backend: /addDocumentsToDocGroups #10

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions ai_ta_backend/database/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,37 @@ def vector_search(self, search_query, course_name, doc_groups: List[str], user_q

return search_results

def add_documents_to_doc_groups(self, course_name: str, doc: dict):
  """
  Update the doc_groups payload for existing documents in the vector database.

  Points are matched by course_name, plus the document's url (when present)
  and its s3_path; each matched point's 'doc_groups' payload field is
  overwritten with the new group list.

  Args:
      course_name (str): Name of the course
      doc (dict): Document object containing url, s3_path, and doc_groups

  Returns:
      Response from Qdrant set_payload operation

  Raises:
      KeyError: if `doc` has no 'doc_groups' key.
  """
  # Every matched point must belong to this course.
  must_conditions = [models.FieldCondition(key='course_name', match=models.MatchValue(value=course_name))]

  # Narrow by URL only when the document actually has one.
  if doc.get('url'):
    must_conditions.append(models.FieldCondition(key='url', match=models.MatchValue(value=doc['url'])))

  # Always narrow by S3 path (empty string when absent).
  must_conditions.append(models.FieldCondition(key='s3_path', match=models.MatchValue(value=doc.get('s3_path', ''))))

  # Create the search filter
  search_filter = models.Filter(must=must_conditions)

  # BUG FIX: qdrant-client's set_payload() accepts the point selector via the
  # `points` parameter (a list of point IDs or a Filter). There is no
  # `points_filter` keyword, so the original call raised TypeError before
  # reaching Qdrant.
  response = self.qdrant_client.set_payload(collection_name=os.environ['QDRANT_COLLECTION_NAME'],
                                            payload={'doc_groups': doc['doc_groups']},
                                            points=search_filter)

  return response

def _create_search_conditions(self, course_name, doc_groups: List[str]):
"""
Create search conditions for the vector search.
Expand Down
40 changes: 27 additions & 13 deletions ai_ta_backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,14 @@
from ai_ta_backend.executors.thread_pool_executor import \
ThreadPoolExecutorInterface
from ai_ta_backend.extensions import db
from ai_ta_backend.redis_queue.ingestQueue import addJobToIngestQueue
from ai_ta_backend.service.export_service import ExportService
#from ai_ta_backend.service.nomic_service import NomicService
from ai_ta_backend.service.posthog_service import PosthogService
from ai_ta_backend.service.retrieval_service import RetrievalService
from ai_ta_backend.service.sentry_service import SentryService
from ai_ta_backend.service.workflow_service import WorkflowService

from ai_ta_backend.redis_queue.ingestQueue import addJobToIngestQueue


app = Flask(__name__)
CORS(app)
executor = Executor(app)
Expand Down Expand Up @@ -184,6 +182,29 @@ def delete(service: RetrievalService, flaskExecutor: ExecutorInterface):
return response


@app.route('/addDocumentsToDocGroups', methods=['POST'])
def addDocumentsToDocGroups(vector: VectorDatabase):
  """
  Update document-group membership for documents in the Qdrant vector database.

  Expects a `course_name` query parameter and a JSON body describing the
  document (url, s3_path, doc_groups). Responds 400 when either is missing;
  otherwise updates Qdrant synchronously and returns a success JSON payload.
  """
  course_name: str = request.args.get('course_name', default='', type=str)
  doc_data = request.get_json()

  # Guard clause: both the course name and a JSON body are required.
  if not course_name or not doc_data:
    # proper web error "400 Bad request"
    abort(400, description=f"Missing required parameters: 'course_name' and request body must be provided. Course name: `{course_name}`")

  t_start = time.monotonic()

  # Runs synchronously — the caller waits for the Qdrant payload update.
  vector.add_documents_to_doc_groups(course_name, doc_data)
  elapsed = time.monotonic() - t_start
  logging.info(f"⏰ Runtime of add doc groups for course: {course_name}: {elapsed:.2f} seconds")

  response = jsonify({"outcome": 'success'})
  response.headers.add('Access-Control-Allow-Origin', '*')
  return response


# @app.route('/getNomicMap', methods=['GET'])
# def nomic_map(service: NomicService):
# course_name: str = request.args.get('course_name', default='', type=str)
Expand All @@ -200,7 +221,6 @@ def delete(service: RetrievalService, flaskExecutor: ExecutorInterface):
# response.headers.add('Access-Control-Allow-Origin', '*')
# return response


# @app.route('/createDocumentMap', methods=['GET'])
# def createDocumentMap(service: NomicService):
# course_name: str = request.args.get('course_name', default='', type=str)
Expand All @@ -215,7 +235,6 @@ def delete(service: RetrievalService, flaskExecutor: ExecutorInterface):
# response.headers.add('Access-Control-Allow-Origin', '*')
# return response


# @app.route('/createConversationMap', methods=['GET'])
# def createConversationMap(service: NomicService):
# course_name: str = request.args.get('course_name', default='', type=str)
Expand All @@ -230,7 +249,6 @@ def delete(service: RetrievalService, flaskExecutor: ExecutorInterface):
# response.headers.add('Access-Control-Allow-Origin', '*')
# return response


# @app.route('/logToConversationMap', methods=['GET'])
# def logToConversationMap(service: NomicService, flaskExecutor: ExecutorInterface):
# course_name: str = request.args.get('course_name', default='', type=str)
Expand All @@ -246,7 +264,6 @@ def delete(service: RetrievalService, flaskExecutor: ExecutorInterface):
# response.headers.add('Access-Control-Allow-Origin', '*')
# return response


# @app.route('/onResponseCompletion', methods=['POST'])
# def logToNomic(service: NomicService, flaskExecutor: ExecutorInterface):
# data = request.get_json()
Expand Down Expand Up @@ -481,6 +498,7 @@ def run_flow(service: WorkflowService) -> Response:
else:
abort(400, description=f"Bad request: {e}")


@app.route('/ingest', methods=['POST'])
def ingest() -> Response:
logging.info("In /ingest")
Expand All @@ -491,12 +509,7 @@ def ingest() -> Response:
result = addJobToIngestQueue(data)
logging.info("Result from addJobToIngestQueue: %s", result)

response = jsonify(
{
"outcome": f'Queued Ingest task',
"ingest_task_id": result
}
)
response = jsonify({"outcome": f'Queued Ingest task', "ingest_task_id": result})
response.headers.add('Access-Control-Allow-Origin', '*')
return response

Expand Down Expand Up @@ -590,6 +603,7 @@ def configure(binder: Binder) -> None:
binder.bind(ProcessPoolExecutorInterface, to=ProcessPoolExecutorAdapter, scope=SingletonScope)
logging.info("Configured all services and adapters", binder._bindings)


FlaskInjector(app=app, modules=[configure])

if __name__ == '__main__':
Expand Down