Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -251,11 +251,9 @@ Pulumi.*.yaml
chat-server/tests/chat_server_tests/*.json
chat-server/tests/chat_server_tests/output

# Chat-server Test datasets
chat-server/tests/e2e/datasets
chat-server/tests/e2e/datasets/*.csv

# Visualization test artifacts
chat-server/tests/e2e/viz_utils/examples_arguments_syntax
chat-server/tests/e2e/datasets
chat-server/tests/e2e/output
chat-server/tests/dspy/output
17 changes: 12 additions & 5 deletions chat-server/app/api/v1/routers/dataset_upload.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
from typing import Annotated

from fastapi import APIRouter, HTTPException, status
from fastapi import APIRouter, Header, HTTPException, status

from app.core.session import SingletonAiohttp
from app.models.router import UploadResponse, UploadSchemaRequest
Expand All @@ -17,7 +18,10 @@


@dataset_router.post("/upload_schema", response_model=UploadResponse)
async def upload_schema(payload: UploadSchemaRequest):
async def upload_schema(
payload: UploadSchemaRequest,
x_organization_id: Annotated[str | None, Header()] = None,
):
"""
Processes and index dataset schema.

Expand All @@ -28,16 +32,19 @@ async def upload_schema(payload: UploadSchemaRequest):
project_id = payload.project_id
dataset_id = payload.dataset_id
dataset_details, project_details = await asyncio.gather(
get_dataset_info(dataset_id, project_id),
get_project_info(project_id),
get_dataset_info(dataset_id, project_id, org_id=x_organization_id),
get_project_info(project_id, org_id=x_organization_id),
)
dataset_summary, sample_data = await generate_summary(
dataset_details.name, org_id=x_organization_id
)
dataset_summary, sample_data = await generate_summary(dataset_details.name)

success = await store_schema_in_qdrant(
dataset_summary=dataset_summary,
sample_data=sample_data,
dataset_details=dataset_details,
project_details=project_details,
org_id=x_organization_id,
)
if not success:
raise HTTPException(
Expand Down
42 changes: 42 additions & 0 deletions chat-server/app/api/v1/routers/fetch_sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from fastapi import APIRouter, HTTPException, status

from app.core.log import custom_logger as logger
from app.models.router import FetchSqlRequest, FetchSqlResponse
from app.workflow.graph.nl_to_sql_graph.graph import nl_to_sql_graph

fetch_sql_router = APIRouter()


@fetch_sql_router.post("/fetch-sql", response_model=FetchSqlResponse)
async def fetch_sql(payload: FetchSqlRequest):
    """
    Generate SQL queries from a natural language description.

    - Single dataset: provide `dataset_ids` with one ID
    - Multi-dataset: provide `dataset_ids` with multiple IDs and/or `project_ids`

    Returns a list of generated SQL queries with their explanations.

    Raises:
        HTTPException: 500 if the NL-to-SQL workflow fails for any reason.
    """
    try:
        # Drop falsy entries (None, "") so the graph only sees real IDs.
        dataset_ids = [did for did in (payload.dataset_ids or []) if did]
        project_ids = [pid for pid in (payload.project_ids or []) if pid]

        result = await nl_to_sql_graph.ainvoke(
            {
                "user_query": payload.description,
                # Empty lists are normalized to None before invoking the graph.
                # NOTE(review): assumes the graph treats None as "no filter" — confirm.
                "dataset_ids": dataset_ids or None,
                "project_ids": project_ids or None,
            }
        )

        return FetchSqlResponse(
            sql_queries=result.get("sql_queries", []),
            message=result.get("message"),
        )
    except Exception as e:
        # exception() records the full traceback; error() only logged the message.
        # NOTE(review): assumes custom_logger exposes the stdlib/loguru
        # `exception` method — confirm.
        logger.exception(f"Error in fetch_sql: {e}")

        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to generate SQL queries, please try again.",
        ) from e
12 changes: 3 additions & 9 deletions chat-server/app/api/v1/routers/health.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@
from datetime import datetime
from typing import Any, Dict

import aiohttp
from fastapi import APIRouter
from fastapi.responses import JSONResponse
from qdrant_client import AsyncQdrantClient

from app.core.config import settings
from app.core.session import SingletonAiohttp

router = APIRouter()

Expand Down Expand Up @@ -103,15 +101,11 @@ async def check_gopie_server_health() -> Dict[str, Any]:
return {"status": "not_configured", "error": "GOPIE_API_ENDPOINT not configured"}

try:
http_session = SingletonAiohttp.get_aiohttp_client()
from app.services.gopie.client import GopieClient

client = GopieClient()
# Test basic connectivity to Gopie server
url = settings.GOPIE_API_ENDPOINT.rstrip("/")
headers = {"accept": "application/json"}

async with http_session.get(
url, headers=headers, timeout=aiohttp.ClientTimeout(total=10)
) as response:
async with await client.get("/") as response:
# Any response (even 404) means the server is reachable
return {"status": "healthy", "response_code": response.status, "server_reachable": True}
except Exception as e:
Expand Down
6 changes: 5 additions & 1 deletion chat-server/app/api/v1/routers/query.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import uuid
from typing import Annotated

from fastapi import APIRouter
from fastapi import APIRouter, Header
from fastapi.responses import JSONResponse, StreamingResponse

from app.utils.adapters.openai.input import (
Expand All @@ -22,6 +23,7 @@ async def root():
@router.post("/chat/completions")
async def create(
openai_format_request: RequestNonStreaming | RequestStreaming,
x_organization_id: Annotated[str | None, Header()] = None,
):
"""
Handle chat completion requests, supporting both streaming and non-streaming responses.
Expand Down Expand Up @@ -53,6 +55,7 @@ async def create(
chat_id=chat_id,
dataset_ids=request.dataset_ids,
project_ids=request.project_ids,
org_id=x_organization_id,
)
),
media_type="text/event-stream",
Expand All @@ -65,5 +68,6 @@ async def create(
chat_id=chat_id,
dataset_ids=request.dataset_ids,
project_ids=request.project_ids,
org_id=x_organization_id,
)
)
6 changes: 6 additions & 0 deletions chat-server/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,10 @@ class Settings(BaseSettings):

QDRANT_HOST: str = "host.docker.local"
QDRANT_COLLECTION: str = "dataset_collection"
QDRANT_DUCKDB_COLLECTION: str = "duckdb_docs_collection"
QDRANT_PORT: int = 6333
QDRANT_TOP_K: int = 5
QDRANT_DUCKDB_TOP_K: int = 5

GOPIE_API_ENDPOINT: str = ""

Expand Down Expand Up @@ -108,6 +110,10 @@ class Settings(BaseSettings):
COLUMN_TRUNCATION_LIMIT: int = 200
DISPLAY_ROWS_AFTER_TRUNCATION_LIMIT: int = 20

# Dataset Sampling Constants
TARGET_ROWS: int = 150000
SAMPLING_THRESHOLD: int = 150000

model_config = SettingsConfigDict(
env_file=".env", extra="ignore", case_sensitive=True, env_prefix="CHAT_"
)
Expand Down
6 changes: 4 additions & 2 deletions chat-server/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from app.api.v1.routers.dataset_upload import (
dataset_router as schema_upload_router,
)
from app.api.v1.routers.fetch_sql import fetch_sql_router
from app.api.v1.routers.health import router as health_router
from app.api.v1.routers.query import router as query_router
from app.core.config import settings
Expand All @@ -19,9 +20,9 @@

@asynccontextmanager
async def lifespan(app: FastAPI):
await QdrantSetup.get_async_client()
await QdrantSetup.get_async_client(settings.QDRANT_COLLECTION)
QdrantSetup.get_sync_client(settings.QDRANT_COLLECTION)
SingletonAiohttp.get_aiohttp_client()
QdrantSetup.get_sync_client()
get_embedding_provider().get_embeddings_model(settings.DEFAULT_EMBEDDING_MODEL)
visualize_graph()

Expand Down Expand Up @@ -59,6 +60,7 @@ async def add_process_time_header(request: Request, call_next):
app.include_router(health_router, prefix=settings.API_V1_STR, tags=["health"])
app.include_router(query_router, prefix=settings.API_V1_STR, tags=["query"])
app.include_router(schema_upload_router, prefix=settings.API_V1_STR, tags=["upload_schema"])
app.include_router(fetch_sql_router, prefix=settings.API_V1_STR, tags=["fetch_sql"])


def start():
Expand Down
16 changes: 9 additions & 7 deletions chat-server/app/models/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class SubQueryInfo:
sql_queries: list[SqlQueryInfo]
tables_used: list[str] | None = None
error_message: list[dict] | None = None
no_sql_response: str | None = None
non_sql_response: str | None = None
retry_count: int = 0
node_messages: dict = field(default_factory=dict)

Expand All @@ -85,7 +85,7 @@ def to_dict(self) -> dict[str, Any]:
"sql_queries": [query.to_dict() for query in self.sql_queries],
"tables_used": self.tables_used,
"error_message": self.error_message,
"no_sql_response": self.no_sql_response,
"non_sql_response": self.non_sql_response,
"retry_count": self.retry_count,
"node_messages": self.node_messages,
}
Expand All @@ -95,19 +95,21 @@ def to_dict(self) -> dict[str, Any]:
class SingleDatasetQueryResult:
user_friendly_dataset_name: str | None
dataset_name: str | None
sql_results: list[SqlQueryInfo] | None
response_for_non_sql: str | None
sql_queries: list[SqlQueryInfo] | None
non_sql_response: str | None
error: str | None
retry_count: int = 0

def to_dict(self) -> dict[str, Any]:
return {
"user_friendly_dataset_name": self.user_friendly_dataset_name,
"dataset_name": self.dataset_name,
"sql_results": (
[result.to_dict() for result in self.sql_results] if self.sql_results else None
"sql_queries": (
[result.to_dict() for result in self.sql_queries] if self.sql_queries else None
),
"response_for_non_sql": self.response_for_non_sql,
"non_sql_response": self.non_sql_response,
"error": self.error,
"retry_count": self.retry_count,
}


Expand Down
23 changes: 22 additions & 1 deletion chat-server/app/models/router.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from langchain_core.messages import BaseMessage
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator

from app.workflow.graph.nl_to_sql_graph.types import SqlQuery


class UploadResponse(BaseModel):
Expand All @@ -20,3 +22,22 @@ class QueryRequest(BaseModel):
chat_id: str | None = None
trace_id: str | None = None
model_id: str | None = None


class FetchSqlRequest(BaseModel):
    """Request body for the /fetch-sql endpoint.

    Scope is given by `dataset_ids` and/or `project_ids`; `description` is the
    natural-language query to translate into SQL. Neither ID list is required
    here — presumably the downstream workflow handles an empty scope; verify.
    """

    # Optional scoping: restrict generation to these projects.
    project_ids: list[str] | None = Field(default=None, description="List of project IDs")
    # Optional scoping: restrict generation to these datasets.
    dataset_ids: list[str] | None = Field(default=None, description="List of dataset IDs")
    # Required natural-language description; no minimum length is enforced.
    description: str = Field(..., description="Natural language description of the data to query")


class FetchSqlResponse(BaseModel):
    """Response body for the /fetch-sql endpoint.

    Guarantees that at least one of `sql_queries` / `message` is populated:
    when both are empty, the after-validator fills in a default message.
    """

    # Generated queries; empty list when nothing could be generated.
    sql_queries: list[SqlQuery] = Field(default_factory=list, description="Generated SQL queries")
    # Human-readable note, normally set only when sql_queries is empty.
    message: str | None = Field(
        default=None, description="Message when no SQL queries are generated"
    )

    @model_validator(mode="after")
    def validate_response(self) -> "FetchSqlResponse":
        # Ensure the client never receives an empty response with no explanation.
        if not self.sql_queries and not self.message:
            self.message = "Unable to generate SQL queries for the given description."
        return self
1 change: 1 addition & 0 deletions chat-server/app/models/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ class DatasetSchema(BaseModel):
dataset_custom_prompt: Optional[str] = None
dataset_description: str
project_id: str
org_id: Optional[str] = None
dataset_id: str
columns: list[ColumnSchema]

Expand Down
97 changes: 97 additions & 0 deletions chat-server/app/services/gopie/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""
Unified client for all Gopie API requests.
Handles authentication, headers, and request management.
"""
from typing import Any, Optional

import aiohttp

from app.core.config import settings
from app.core.log import custom_logger as logger
from app.core.session import SingletonAiohttp


class GopieClient:
    """
    Unified client for making requests to the Gopie API.

    Resolves request paths against the configured ``GOPIE_API_ENDPOINT``,
    automatically attaches the organization header when an org ID is set, and
    reuses the process-wide aiohttp session (this client never owns or closes
    the session). Responses are returned un-released, so callers should use
    them as async context managers: ``async with await client.get(...) as r``.
    """

    def __init__(self, org_id: Optional[str] = None):
        """
        Initialize the Gopie client.

        Args:
            org_id: Organization ID to include in request headers
        """
        # Trailing slash is stripped once here; _url() re-inserts exactly one.
        self.base_url = settings.GOPIE_API_ENDPOINT.rstrip("/")
        self.org_id = org_id
        self._session = SingletonAiohttp.get_aiohttp_client()

    def _url(self, path: str) -> str:
        """
        Join *path* onto the base URL.

        Tolerates paths with or without a leading slash, so "v1/x" and
        "/v1/x" both resolve to "<base>/v1/x".
        """
        return f"{self.base_url}/{path.lstrip('/')}"

    def _get_headers(self, additional_headers: Optional[dict[str, str]] = None) -> dict[str, str]:
        """
        Build headers for the request, including org-id if available.

        Args:
            additional_headers: Optional additional headers to include
                (may override the defaults, since they are merged last)

        Returns:
            Dictionary of headers
        """
        headers = {"accept": "application/json"}

        if self.org_id:
            headers["X-Organization-id"] = self.org_id

        if additional_headers:
            headers.update(additional_headers)

        return headers

    async def get(
        self,
        path: str,
        headers: Optional[dict[str, str]] = None,
        **kwargs: Any,
    ) -> "aiohttp.ClientResponse":
        """
        Make a GET request to the Gopie API.

        Args:
            path: API path (joined onto base_url; leading slash optional)
            headers: Optional additional headers
            **kwargs: Additional arguments to pass to aiohttp

        Returns:
            aiohttp ClientResponse (not yet released; use ``async with``)
        """
        url = self._url(path)
        request_headers = self._get_headers(headers)

        logger.debug(f"GET request to {url}")
        return await self._session.get(url, headers=request_headers, **kwargs)

    async def post(
        self,
        path: str,
        json: Optional[dict[str, Any]] = None,
        headers: Optional[dict[str, str]] = None,
        **kwargs: Any,
    ) -> "aiohttp.ClientResponse":
        """
        Make a POST request to the Gopie API.

        Args:
            path: API path (joined onto base_url; leading slash optional)
            json: JSON payload for the request
            headers: Optional additional headers
            **kwargs: Additional arguments to pass to aiohttp

        Returns:
            aiohttp ClientResponse (not yet released; use ``async with``)
        """
        url = self._url(path)
        request_headers = self._get_headers(headers)

        logger.debug(f"POST request to {url}")
        return await self._session.post(url, json=json, headers=request_headers, **kwargs)
Comment thread
paul-tharun marked this conversation as resolved.
Loading