176 changes: 174 additions & 2 deletions api/endpoints/answerQuestion.py
@@ -10,13 +10,14 @@
"""

import os
import json
import logging
import traceback

from pydantic import BaseModel, Field
from typing import Dict, List, Literal

from fastapi.responses import JSONResponse
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.encoders import jsonable_encoder
from fastapi import APIRouter, Depends, HTTPException, Query

@@ -250,4 +251,175 @@ async def process_question(request_data: answerQuestionRequest, auth: str):
    response['llm_provider'] = request_data.llm_provider
    response['llm_model'] = request_data.llm_model

    return JSONResponse(content=jsonable_encoder(response), media_type='application/json')


@router.get( # SSE response
    '/answerQuestion/stream',
    response_class = StreamingResponse,
    tags = ['Ask a Question']
)
@handle_endpoint_error("answerQuestion")
async def answer_question_stream(
    request: answerQuestionRequest = Query(),
    auth: str = Depends(authenticate),
):
    """This endpoint processes a natural language question using SSE (Server-Sent Events):

    - Streams real-time progress updates
    - Searches for relevant tables using vector search
    - Determines whether the question should be answered using a SQL query or a metadata search
    - Generates a VQL query using an LLM if the question should be answered using a SQL query
    - Executes the VQL query and gets the data
    - Generates an answer to the question using the data and the VQL query

    This endpoint will also automatically look for the following values in the environment variables for convenience:

    - EMBEDDINGS_PROVIDER
    - EMBEDDINGS_MODEL
    - VECTOR_STORE
    - LLM_PROVIDER
    - LLM_MODEL
    - LLM_TEMPERATURE
    - LLM_MAX_TOKENS
    - CUSTOM_INSTRUCTIONS
    - VQL_EXECUTE_ROWS_LIMIT
    - LLM_RESPONSE_ROWS_LIMIT

    You can also override the LLM temperature and max_tokens via API parameters to fine-tune the model behavior."""
    return StreamingResponse(
        sse_process_question(request, auth),
        media_type="text/event-stream"
    )
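
For reference, a minimal client-side sketch (not part of this diff) of how the stream can be consumed. The base URL, token, and bearer auth scheme are placeholders; only the `question` query parameter is taken from the request model used here.

import json
import httpx

BASE_URL = "http://localhost:8008"  # hypothetical deployment URL
TOKEN = "your-api-token"            # hypothetical credential

def ask_streaming(question: str):
    """Consume the /answerQuestion/stream SSE endpoint event by event."""
    with httpx.stream(
        "GET",
        f"{BASE_URL}/answerQuestion/stream",
        params={"question": question},
        headers={"Authorization": f"Bearer {TOKEN}"},
        timeout=None,  # the stream can stay open while the LLM works
    ) as response:
        for line in response.iter_lines():
            if not line.startswith("data: "):
                continue  # skip blank separators between events
            event = json.loads(line[len("data: "):])
            if "error" in event:
                raise RuntimeError(event["error"])
            print(event["status"])
            if event["status"] == "completed":
                return event["result"]

# Example: answer = ask_streaming("What were last quarter's sales?")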


async def sse_process_question(request_data: answerQuestionRequest, auth: str):
    """Generator function to process the question and stream progress updates via SSE"""
    # Generate session ID for Langfuse debugging purposes
    session_id = generate_session_id(request_data.question)

    try:
        llm = state_manager.get_llm(
            provider_name=request_data.llm_provider,
            model_name=request_data.llm_model,
            temperature=request_data.llm_temperature,
            max_tokens=request_data.llm_max_tokens
        )

        vector_store = state_manager.get_vector_store(
            provider=request_data.vector_store_provider,
            embeddings_provider=request_data.embeddings_provider,
            embeddings_model=request_data.embeddings_model
        )
        sample_data_vector_store = state_manager.get_vector_store(
            provider=request_data.vector_store_provider,
            embeddings_provider=request_data.embeddings_provider,
            embeddings_model=request_data.embeddings_model,
            index_name="ai_sdk_sample_data"
        )
    except Exception as e:
        logging.error(f"Resource initialization error: {str(e)}")
        logging.error(f"Resource initialization traceback: {traceback.format_exc()}")
        error_data = {"error": f"Error initializing resources: {str(e)}"}
        yield f"data: {json.dumps(error_data)}\n\n"
        return

    vector_search_tables, sample_data, timings = await sdk_ai_tools.get_relevant_tables(
        query=request_data.question,
        vector_store=vector_store,
        sample_data_vector_store=sample_data_vector_store,
        vdb_list=request_data.vdp_database_names,
        tag_list=request_data.vdp_tag_names,
        auth=auth,
        k=request_data.vector_search_k,
        use_views=request_data.use_views,
        expand_set_views=request_data.expand_set_views,
        vector_search_sample_data_k=request_data.vector_search_sample_data_k,
        allow_external_associations=request_data.allow_external_associations
    )

    if not vector_search_tables:
        error_data = {"error": "The vector search result returned 0 views. This could be due to limited permissions or an empty vector store."}
        yield f"data: {json.dumps(error_data)}\n\n"
        return

    # Send vector search completion message
    tables_list = [table.get('database_name', '') + '.' + table.get('view_name', '') for table in vector_search_tables]
    vector_progress_data = {
        "status": "vector_search_completed",
        "result": {
            "answer": tables_list,
            "message": "Vector search completed",
            "tables_found": len(vector_search_tables)
        }
    }

    yield f"data: {json.dumps(vector_progress_data)}\n\n"

    # Combine custom instructions from environment and request
    base_instructions = os.getenv('CUSTOM_INSTRUCTIONS', '')
    if request_data.custom_instructions:
        request_data.custom_instructions = f"{base_instructions}\n{request_data.custom_instructions}".strip()
    else:
        request_data.custom_instructions = base_instructions

    with timing_context("llm_time", timings):
        category, category_response, category_related_questions, sql_category_tokens = await sdk_ai_tools.sql_category(
            query=request_data.question,
            vector_search_tables=vector_search_tables,
            llm=llm,
            mode=request_data.mode,
            custom_instructions=request_data.custom_instructions,
            session_id=session_id
        )

    if category == "SQL":
        async for chunk in sdk_answer_question.async_gen_process_sql_category(
            request=request_data,
            vector_search_tables=vector_search_tables,
            category_response=category_response,
            auth=auth,
            timings=timings,
            session_id=session_id,
            sample_data=sample_data,
            chat_llm=llm,
            sql_gen_llm=llm
        ):
            event = chunk.get("type")
            c_data = chunk.get("data")

            if event == "query_gen":
                vql_progress_data = {
                    "status": "vql_generation_completed",
                    "result": {
                        "answer": c_data,
                        "message": "VQL generation completed"
                    }
                }
                yield f"data: {json.dumps(vql_progress_data)}\n\n"
            elif event == "result":
                response = c_data

        response['tokens'] = add_tokens(response['tokens'], sql_category_tokens)
    elif category == "METADATA":
        response = sdk_answer_question.process_metadata_category(
            category_response=category_response,
            category_related_questions=category_related_questions,
            vector_search_tables=vector_search_tables,
            timings=timings,
            tokens=sql_category_tokens,
            disclaimer=request_data.disclaimer
        )
    else:
        response = sdk_answer_question.process_unknown_category(timings=timings)

    response['llm_provider'] = request_data.llm_provider
    response['llm_model'] = request_data.llm_model

    # Send final completion message
    final_data = {
        "status": "completed",
        "result": jsonable_encoder(response)
    }
    yield f"data: {json.dumps(final_data)}\n\n"
137 changes: 136 additions & 1 deletion api/utils/sdk_answer_question.py
@@ -373,4 +373,139 @@ async def enhance_verbose_response(
    # Update timings with the latest llm_time and total_execution_time
    response['llm_time'] = timings.get("llm_time", 0)
    response['total_execution_time'] = round(sum(timings.values()), 2)
    return response


# Async generator: yields progress chunks while processing the SQL category, then the final response
async def async_gen_process_sql_category(request, vector_search_tables, sql_gen_llm, chat_llm, category_response, auth, timings, session_id = None, sample_data = None):
    with timing_context("llm_time", timings):
        vql_query, query_explanation, query_to_vql_tokens = await sdk_ai_tools.query_to_vql(
            query=request.question,
            vector_search_tables=vector_search_tables,
            llm=sql_gen_llm,
            filter_params=category_response,
            custom_instructions=request.custom_instructions,
            vector_search_sample_data_k=request.vector_search_sample_data_k,
            session_id=session_id,
            sample_data=sample_data
        )

    # Early exit if no valid VQL could be generated
    if not vql_query:
        response = prepare_response(
            vql_query='',
            query_explanation=query_explanation,
            tokens=query_to_vql_tokens,
            execution_result={},
            vector_search_tables=vector_search_tables,
            raw_graph='',
            timings=timings
        )
        response['answer'] = 'No VQL query was generated because no relevant schema was found.'
        yield {"type": "result", "data": response}
        return  # stop the generator here; without this, execution would fall through to the fixer

    vql_query, _, query_fixer_tokens = await sdk_ai_tools.query_fixer(
        question=request.question,
        query=vql_query,
        query_explanation=query_explanation,
        llm=sql_gen_llm,
        session_id=session_id,
        vector_search_sample_data_k=request.vector_search_sample_data_k,
        vector_search_tables=vector_search_tables,
        sample_data=sample_data
    )

    max_attempts = 2
    attempt = 0
    fixer_history = []
    original_vql_query = vql_query

    # Yield the generated VQL as an early progress chunk
    yield {"type": "query_gen", "data": original_vql_query}

    while attempt < max_attempts:
        vql_query, execution_result, vql_status_code, timings, fixer_history, query_fixer_tokens = await attempt_query_execution(
            vql_query=vql_query,
            request=request,
            auth=auth,
            timings=timings,
            vector_search_tables=vector_search_tables,
            session_id=session_id,
            query_explanation=query_explanation,
            query_fixer_tokens=query_fixer_tokens,
            fixer_history=fixer_history,
            sample_data=sample_data,
            llm=sql_gen_llm
        )

        if attempt == 0:
            original_execution_result = execution_result
            original_vql_status_code = vql_status_code

        if vql_query == 'OK':
            vql_query = original_vql_query
            break
        elif vql_status_code not in [499, 500]:
            break

        attempt += 1

    if vql_status_code in [499, 500]:
        if vql_query:
            execution_result, vql_status_code, timings = await execute_query(
                vql_query=vql_query,
                auth=auth,
                limit=request.vql_execute_rows_limit,
                timings=timings
            )
            if vql_status_code == 500 or (vql_status_code == 499 and original_vql_status_code == 499):
                vql_query = original_vql_query
                execution_result = original_execution_result
                vql_status_code = original_vql_status_code
        else:
            vql_status_code = 500
            execution_result = "No VQL query was generated."

    llm_execution_result = prepare_execution_result(
        execution_result=execution_result,
        llm_response_rows_limit=request.llm_response_rows_limit,
        vql_status_code=vql_status_code
    )

    raw_graph, plot_data, request = handle_plotting(request=request, execution_result=execution_result)

    response = prepare_response(
        vql_query=vql_query,
        query_explanation=query_explanation,
        tokens=add_tokens(query_to_vql_tokens, query_fixer_tokens),
        execution_result=execution_result if vql_status_code == 200 else {},
        vector_search_tables=vector_search_tables,
        raw_graph=raw_graph,
        timings=timings
    )

    if request.verbose or request.plot:
        response = await enhance_verbose_response(
            request=request,
            response=response,
            vql_query=vql_query,
            llm_execution_result=llm_execution_result,
            vector_search_tables=vector_search_tables,
            plot_data=plot_data,
            timings=timings,
            session_id=session_id,
            sample_data=sample_data,
            chat_llm=chat_llm,
            sql_gen_llm=sql_gen_llm
        )

    if request.disclaimer:
        response['answer'] += "\n\nDISCLAIMER: This response has been generated based on an LLM's interpretation of the data and may not be accurate."

    # Yield the final chunk
    yield {"type": "result", "data": response}