# ============================================================================
# Reconstructed from a whitespace-mangled unified diff. Two files were touched:
#   api/endpoints/answerQuestion.py  -> new SSE endpoint + generator
#   api/utils/sdk_answer_question.py -> new async generator for the SQL path
# File-level additions in answerQuestion.py: `import json` and
# `StreamingResponse` added to the fastapi.responses import.
# Fixes applied (called out inline with FIX tags):
#   1. missing `return` after the early-exit yield in
#      async_gen_process_sql_category (fell through with an empty VQL query)
#   2. unguarded vector-search call in sse_process_question (an exception
#      there aborted the SSE stream with no error frame for the client)
#   3. docstring typo "the the"; dead commented-out code removed
# ============================================================================


# ---- api/endpoints/answerQuestion.py ---------------------------------------

@router.get(  # SSE response
    '/answerQuestion/stream',
    response_class=StreamingResponse,
    tags=['Ask a Question']
)
@handle_endpoint_error("answerQuestion")
async def answer_question_stream(
    request: answerQuestionRequest = Query(),
    auth: str = Depends(authenticate),
):
    """This endpoint processes a natural language question using SSE (Server-Sent Events):

    - Streams real-time progress updates
    - Searches for relevant tables using vector search
    - Determines whether the question should be answered using a SQL query or a metadata search
    - Generates a VQL query using an LLM if the question should be answered using a SQL query
    - Executes the VQL query and gets the data
    - Generates an answer to the question using the data and the VQL query

    This endpoint will also automatically look for the following values in the environment variables for convenience:

    - EMBEDDINGS_PROVIDER
    - EMBEDDINGS_MODEL
    - VECTOR_STORE
    - LLM_PROVIDER
    - LLM_MODEL
    - LLM_TEMPERATURE
    - LLM_MAX_TOKENS
    - CUSTOM_INSTRUCTIONS
    - VQL_EXECUTE_ROWS_LIMIT
    - LLM_RESPONSE_ROWS_LIMIT

    You can also override the LLM temperature and max_tokens via API parameters for fine-tuning the model behavior."""
    return StreamingResponse(
        sse_process_question(request, auth),
        media_type="text/event-stream"
    )


async def sse_process_question(request_data: answerQuestionRequest, auth: str):
    """Async generator: process the question and stream progress updates via SSE.

    Yields ``data: <json>\\n\\n`` frames. Error frames carry an ``"error"`` key;
    progress frames carry ``"status"``/``"result"``; the last frame has
    ``"status": "completed"`` with the full jsonable response.
    """
    # Generate session ID for Langfuse debugging purposes
    session_id = generate_session_id(request_data.question)

    try:
        llm = state_manager.get_llm(
            provider_name=request_data.llm_provider,
            model_name=request_data.llm_model,
            temperature=request_data.llm_temperature,
            max_tokens=request_data.llm_max_tokens
        )

        vector_store = state_manager.get_vector_store(
            provider=request_data.vector_store_provider,
            embeddings_provider=request_data.embeddings_provider,
            embeddings_model=request_data.embeddings_model
        )
        sample_data_vector_store = state_manager.get_vector_store(
            provider=request_data.vector_store_provider,
            embeddings_provider=request_data.embeddings_provider,
            embeddings_model=request_data.embeddings_model,
            index_name="ai_sdk_sample_data"
        )
    except Exception as e:
        logging.error(f"Resource initialization error: {str(e)}")
        logging.error(f"Resource initialization traceback: {traceback.format_exc()}")
        error_data = {"error": f"Error initializing resources: {str(e)}"}
        yield f"data: {json.dumps(error_data)}\n\n"
        return

    # FIX(robustness): this call was previously unguarded; any exception here
    # would abort the SSE stream mid-flight with no error frame for the client
    # (@handle_endpoint_error only wraps the endpoint, not the streamed body).
    try:
        vector_search_tables, sample_data, timings = await sdk_ai_tools.get_relevant_tables(
            query=request_data.question,
            vector_store=vector_store,
            sample_data_vector_store=sample_data_vector_store,
            vdb_list=request_data.vdp_database_names,
            tag_list=request_data.vdp_tag_names,
            auth=auth,
            k=request_data.vector_search_k,
            use_views=request_data.use_views,
            expand_set_views=request_data.expand_set_views,
            vector_search_sample_data_k=request_data.vector_search_sample_data_k,
            allow_external_associations=request_data.allow_external_associations
        )
    except Exception as e:
        logging.error(f"Vector search error: {str(e)}")
        logging.error(f"Vector search traceback: {traceback.format_exc()}")
        error_data = {"error": f"Error searching for relevant tables: {str(e)}"}
        yield f"data: {json.dumps(error_data)}\n\n"
        return

    if not vector_search_tables:
        error_data = {"error": "The vector search result returned 0 views. This could be due to limited permissions or an empty vector store."}
        yield f"data: {json.dumps(error_data)}\n\n"
        return

    # Send vector search completion message
    tables_list = [table.get('database_name', '') + '.' + table.get('view_name', '') for table in vector_search_tables]
    vector_progress_data = {
        "status": "vector_search_completed",
        "result": {
            "answer": tables_list,
            "message": "Vector search completed",
            "tables_found": len(vector_search_tables)
        }
    }
    yield f"data: {json.dumps(vector_progress_data)}\n\n"

    # Combine custom instructions from environment and request
    base_instructions = os.getenv('CUSTOM_INSTRUCTIONS', '')
    if request_data.custom_instructions:
        request_data.custom_instructions = f"{base_instructions}\n{request_data.custom_instructions}".strip()
    else:
        request_data.custom_instructions = base_instructions

    with timing_context("llm_time", timings):
        category, category_response, category_related_questions, sql_category_tokens = await sdk_ai_tools.sql_category(
            query=request_data.question,
            vector_search_tables=vector_search_tables,
            llm=llm,
            mode=request_data.mode,
            custom_instructions=request_data.custom_instructions,
            session_id=session_id
        )

    if category == "SQL":
        # The generator's contract: zero or more "query_gen" chunks, then
        # exactly one final "result" chunk — so `response` is always bound.
        async for chunk in sdk_answer_question.async_gen_process_sql_category(
            request=request_data,
            vector_search_tables=vector_search_tables,
            category_response=category_response,
            auth=auth,
            timings=timings,
            session_id=session_id,
            sample_data=sample_data,
            chat_llm=llm,
            sql_gen_llm=llm
        ):
            event = chunk.get("type")
            c_data = chunk.get("data")

            if event == "query_gen":
                vql_progress_data = {
                    "status": "vql_generation_completed",
                    "result": {
                        "answer": c_data,
                        "message": "VQL generation completed"
                    }
                }
                yield f"data: {json.dumps(vql_progress_data)}\n\n"
            elif event == "result":
                response = c_data

        response['tokens'] = add_tokens(response['tokens'], sql_category_tokens)
    elif category == "METADATA":
        response = sdk_answer_question.process_metadata_category(
            category_response=category_response,
            category_related_questions=category_related_questions,
            vector_search_tables=vector_search_tables,
            timings=timings,
            tokens=sql_category_tokens,
            disclaimer=request_data.disclaimer
        )
    else:
        response = sdk_answer_question.process_unknown_category(timings=timings)

    response['llm_provider'] = request_data.llm_provider
    response['llm_model'] = request_data.llm_model

    # Send final completion message
    final_data = {
        "status": "completed",
        "result": jsonable_encoder(response)
    }
    yield f"data: {json.dumps(final_data)}\n\n"


# ---- api/utils/sdk_answer_question.py --------------------------------------

async def async_gen_process_sql_category(request, vector_search_tables, sql_gen_llm, chat_llm, category_response, auth, timings, session_id = None, sample_data = None):
    """Async generator for the SQL category of answerQuestion.

    Yields:
        {"type": "query_gen", "data": <vql str>} once the VQL is generated,
        then {"type": "result", "data": <response dict>} as the final chunk.
    """
    with timing_context("llm_time", timings):
        vql_query, query_explanation, query_to_vql_tokens = await sdk_ai_tools.query_to_vql(
            query=request.question,
            vector_search_tables=vector_search_tables,
            llm=sql_gen_llm,
            filter_params=category_response,
            custom_instructions=request.custom_instructions,
            vector_search_sample_data_k=request.vector_search_sample_data_k,
            session_id=session_id,
            sample_data=sample_data
        )

    # Early exit if no valid VQL could be generated
    if not vql_query:
        response = prepare_response(
            vql_query='',
            query_explanation=query_explanation,
            tokens=query_to_vql_tokens,
            execution_result={},
            vector_search_tables=vector_search_tables,
            raw_graph='',
            timings=timings
        )
        response['answer'] = 'No VQL query was generated because no relevant schema was found.'
        yield {"type": "result", "data": response}
        # FIX(bug): without this return the generator fell through and kept
        # running the fixer/execution pipeline with an empty VQL query.
        return

    vql_query, _, query_fixer_tokens = await sdk_ai_tools.query_fixer(
        question=request.question,
        query=vql_query,
        query_explanation=query_explanation,
        llm=sql_gen_llm,
        session_id=session_id,
        vector_search_sample_data_k=request.vector_search_sample_data_k,
        vector_search_tables=vector_search_tables,
        sample_data=sample_data
    )

    max_attempts = 2
    attempt = 0
    fixer_history = []
    original_vql_query = vql_query

    # Stream the generated VQL before attempting execution
    yield {"type": "query_gen", "data": original_vql_query}

    while attempt < max_attempts:
        vql_query, execution_result, vql_status_code, timings, fixer_history, query_fixer_tokens = await attempt_query_execution(
            vql_query=vql_query,
            request=request,
            auth=auth,
            timings=timings,
            vector_search_tables=vector_search_tables,
            session_id=session_id,
            query_explanation=query_explanation,
            query_fixer_tokens=query_fixer_tokens,
            fixer_history=fixer_history,
            sample_data=sample_data,
            llm=sql_gen_llm
        )

        # Keep the first attempt's outcome so we can fall back to it if the
        # fixer's rewrite fares no better.
        if attempt == 0:
            original_execution_result = execution_result
            original_vql_status_code = vql_status_code

        if vql_query == 'OK':
            # Fixer judged the original query correct — restore it.
            vql_query = original_vql_query
            break
        elif vql_status_code not in [499, 500]:
            break

        attempt += 1

    if vql_status_code in [499, 500]:
        if vql_query:
            execution_result, vql_status_code, timings = await execute_query(
                vql_query=vql_query,
                auth=auth,
                limit=request.vql_execute_rows_limit,
                timings=timings
            )
            if vql_status_code == 500 or (vql_status_code == 499 and original_vql_status_code == 499):
                # Rewrite did not improve things; report the original attempt.
                vql_query = original_vql_query
                execution_result = original_execution_result
                vql_status_code = original_vql_status_code
        else:
            vql_status_code = 500
            execution_result = "No VQL query was generated."

    llm_execution_result = prepare_execution_result(
        execution_result=execution_result,
        llm_response_rows_limit=request.llm_response_rows_limit,
        vql_status_code=vql_status_code
    )

    raw_graph, plot_data, request = handle_plotting(request=request, execution_result=execution_result)

    response = prepare_response(
        vql_query=vql_query,
        query_explanation=query_explanation,
        tokens=add_tokens(query_to_vql_tokens, query_fixer_tokens),
        execution_result=execution_result if vql_status_code == 200 else {},
        vector_search_tables=vector_search_tables,
        raw_graph=raw_graph,
        timings=timings
    )

    if request.verbose or request.plot:
        response = await enhance_verbose_response(
            request=request,
            response=response,
            vql_query=vql_query,
            llm_execution_result=llm_execution_result,
            vector_search_tables=vector_search_tables,
            plot_data=plot_data,
            timings=timings,
            session_id=session_id,
            sample_data=sample_data,
            chat_llm=chat_llm,
            sql_gen_llm=sql_gen_llm
        )

    if request.disclaimer:
        response['answer'] += "\n\nDISCLAIMER: This response has been generated based on an LLM's interpretation of the data and may not be accurate."

    # Final chunk
    yield {"type": "result", "data": response}