diff --git a/sql_agent/.dockerignore b/sql_agent/.dockerignore
deleted file mode 100644
index 872084e0..00000000
--- a/sql_agent/.dockerignore
+++ /dev/null
@@ -1,5 +0,0 @@
-*test.py
-.git
-.env
-.vscode/*
-*.env
\ No newline at end of file
diff --git a/sql_agent/.env.example b/sql_agent/.env.example
deleted file mode 100644
index 1c4db89a..00000000
--- a/sql_agent/.env.example
+++ /dev/null
@@ -1,15 +0,0 @@
-SLACK_BOT_TOKEN=token
-SLACK_APP_TOKEN=token
-SHAPE_ANALYTICS_API_KEY=token
-OPENAI_API_KEY=token
-
-DB_TYPE=value
-
-BQ_API_KEY=token
-
-SNOWFLAKE_ACCOUNT=value
-SNOWFLAKE_USER=value
-SNOWFLAKE_PASSWORD=value
-SNOWFLAKE_DATABASE=value
-SNOWFLAKE_SCHEMA=value
-SNOWFLAKE_TABLE=value
\ No newline at end of file
diff --git a/sql_agent/.gcloudignore b/sql_agent/.gcloudignore
deleted file mode 100644
index d6c92689..00000000
--- a/sql_agent/.gcloudignore
+++ /dev/null
@@ -1,4 +0,0 @@
-*test.py
-.git
-.env
-.vscode/*
\ No newline at end of file
diff --git a/sql_agent/.prettierrc b/sql_agent/.prettierrc
deleted file mode 100644
index 5a938ce1..00000000
--- a/sql_agent/.prettierrc
+++ /dev/null
@@ -1,4 +0,0 @@
-{
-    "tabWidth": 4,
-    "useTabs": false
-}
diff --git a/sql_agent/.vscode/launch.json b/sql_agent/.vscode/launch.json
deleted file mode 100644
index a005a712..00000000
--- a/sql_agent/.vscode/launch.json
+++ /dev/null
@@ -1,26 +0,0 @@
-{
-    "version": "0.2.0",
-    "configurations": [
-        {
-            "name": "Python: functions_framework",
-            "type": "python",
-            "request": "launch",
-            "module": "functions_framework",
-            "justMyCode": false,
-            "args": [
-                "--target",
-                "runSQLAgent",
-                "--debug"
-            ]
-        },
-        {
-            "name": "Python: Current File",
-            "type": "python",
-            "request": "launch",
-            "program": "${file}",
-            "justMyCode": false,
-            "console": "integratedTerminal"
-        }
-    ]
-}
\ No newline at end of file
diff --git a/sql_agent/Dockerfile b/sql_agent/Dockerfile
deleted file mode 100644
index 9eb3ebab..00000000
--- a/sql_agent/Dockerfile
+++ /dev/null
@@ -1,16 +0,0 @@
-# Use the official Python image.
-# https://hub.docker.com/_/python
-FROM python:3.11.3-slim
-
-# Set the working directory to /app
-WORKDIR /app
-
-# Copy the current directory contents into the container at /app
-COPY . /app
-
-# Install any needed packages specified in requirements.txt
-RUN pip install --trusted-host pypi.python.org -r requirements.txt
-
-ENTRYPOINT ["python"]
-
-CMD ["shape_bot_app.py"]
diff --git a/sql_agent/README.md b/sql_agent/README.md
deleted file mode 100644
index d4c31394..00000000
--- a/sql_agent/README.md
+++ /dev/null
@@ -1,39 +0,0 @@
-# SQL Agent
-
-Takes a user query and runs a SQL agent over the configured database.
-
-Can be accessed via:
-
-1. Slackbot
-2. HTTP Functions
-
-Supports:
-
-1. Snowflake
-2. Google BigQuery
-3. DuckDB
-
-## Development
-
-### Env Vars
-
-```bash
-OPENAI_API_KEY=key
-DB_TYPE=snowflake|bigquery|duckdb
-```
-
-### HTTP Functions
-
-```bash
-functions-framework --target runSQLAgent --debug
-```
-
-### DuckDB Usage
-
-Load a CSV from disk by setting the DUCKDB_CSV_PATH env var.
-
-This loads the file into an in-memory table like this:
-
-```sql
-CREATE TABLE leads AS SELECT * FROM read_csv_auto('$DUCKDB_CSV_PATH');
-```
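For reference, the DuckDB flow described in this README reduces to a short standalone sketch. It assumes the `duckdb-engine` package (which provides the `duckdb://` SQLAlchemy dialect and is not pinned in `requirements.txt` below) and a hypothetical `leads.csv`:

```python
# Minimal sketch of the CSV-to-in-memory-table flow from the README.
# Assumes `duckdb-engine` is installed so SQLAlchemy understands the
# duckdb:// URL; `leads.csv` is a hypothetical input file.
from sqlalchemy import create_engine

engine = create_engine("duckdb:///:memory:")
with engine.connect() as con:
    # read_csv_auto infers column names and types from the file itself
    con.execute("CREATE TABLE leads AS SELECT * FROM read_csv_auto('leads.csv')")
    print(con.execute("SELECT COUNT(*) FROM leads").fetchone())
```

The in-memory database lives and dies with the engine, which is why `database_factory.py` (next) creates the table eagerly, before handing the engine to LangChain.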
shapeAnalytics.track("agent_thread Completed", {"Query": query}) - threadedGntr.close() - - -def sqlChain(query: str, username: str) -> ThreadedGenerator: - shapeAnalytics = ShapeAnalytics(username) - shapeAnalytics.track("sqlChain Invoked", {"Query": query}) - threadedGntr = ThreadedGenerator() - threading.Thread( - target=agent_thread, args=(threadedGntr, query, shapeAnalytics, None) - ).start() - return threadedGntr - - -def slackSqlChain( - query: str, username: str, sendMessage: Say, thread_ts: str, channel: str -) -> ThreadedGenerator: - shapeAnalytics = ShapeAnalytics(username) - shapeAnalytics.track("slackSqlChain Invoked", {"Query": query}) - threadedGntr = ThreadedGenerator() - slackData = SlackData( - username=username, sendMessage=sendMessage, thread_ts=thread_ts, channel=channel - ) - threading.Thread( - target=agent_thread, args=(threadedGntr, query, shapeAnalytics, slackData) - ).start() - slackData.send("ok, this'll take me about 1 or 2 minutes to figure out") - return threadedGntr - - -@functions_framework.http -def runSQLAgent(request): - if request.method == "OPTIONS": - # Allows GET requests from any origin with the Content-Type - # header and caches preflight response for an 3600s - headers = { - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Methods": "GET", - "Access-Control-Allow-Headers": "Content-Type", - "Access-Control-Max-Age": "3600", - } - return ("", 204, headers) - # Set CORS headers for the main request - headers = {"Access-Control-Allow-Origin": "*"} - request_args = request.args - query = request_args["query"] - if query: - query = escape(query) - return Response(response=sqlChain(query, "HTTP call"), headers=headers) - else: - return ("Please provide 'query' param", 400, headers) diff --git a/sql_agent/requirements.txt b/sql_agent/requirements.txt deleted file mode 100644 index c3c209a1..00000000 --- a/sql_agent/requirements.txt +++ /dev/null @@ -1,17 +0,0 @@ -functions-framework==3.2.1 -flask==2.1.0 -google-cloud-error-reporting==1.9.1 -langchain==0.0.150 -openai==0.27.2 -mixpanel==4.10.0 -certifi==2022.12.7 -idna==3.4 -python-decouple==3.8 -slack-bolt==1.17.1 -slack-sdk==3.21.1 -urllib3==1.26.15 -snowflake-sqlalchemy==1.4.7 -snowflake-connector-python==3.0.3 -sqlalchemy-bigquery==1.6.1 -sqlalchemy==1.4.47 -matplotlib==3.7.1 \ No newline at end of file diff --git a/sql_agent/shape_analytics.py b/sql_agent/shape_analytics.py deleted file mode 100644 index ad02b0b1..00000000 --- a/sql_agent/shape_analytics.py +++ /dev/null @@ -1,19 +0,0 @@ -from mixpanel import Mixpanel -from decouple import config - - -class ShapeAnalytics: - def __init__(self, distinctId: str): - shape_analytics_api_key = config("SHAPE_ANALYTICS_API_KEY", default=None) - # If customer has not opted into analytics via shape analytics api key, don't initialize Mixpanel - self.mixpanel = ( - None - if shape_analytics_api_key is None - else Mixpanel(shape_analytics_api_key) - ) - self.distinctId = distinctId - - def track(self, event_name: str, properties=None, meta=None) -> None: - if self.mixpanel is None: - return - self.mixpanel.track(self.distinctId, event_name, properties, meta) diff --git a/sql_agent/shape_bot_app.py b/sql_agent/shape_bot_app.py deleted file mode 100644 index a3577507..00000000 --- a/sql_agent/shape_bot_app.py +++ /dev/null @@ -1,45 +0,0 @@ -""" Styleguide: https://google.github.io/styleguide/pyguide.html """ -import os -from slack_bolt import App -from slack_bolt.adapter.socket_mode import SocketModeHandler -from slack_bolt.context.say import Say -from 
diff --git a/sql_agent/main.py b/sql_agent/main.py
deleted file mode 100644
index cceb4099..00000000
--- a/sql_agent/main.py
+++ /dev/null
@@ -1,92 +0,0 @@
-""" Styleguide: https://google.github.io/styleguide/pyguide.html """
-import functions_framework
-import threading
-from flask import Response
-from flask import escape
-from langchain.callbacks.base import CallbackManager
-from langchain.callbacks.stdout import StdOutCallbackHandler
-from shape_sql_callback import ShapeSQLCallbackHandler
-from shape_sql_callback import create_shape_sql_agent
-from threaded_generator import ThreadedGenerator
-from shape_analytics import ShapeAnalytics
-from slack_data import SlackData
-from slack_bolt.context.say import Say
-from typing import Optional
-
-
-def agent_thread(
-    threadedGntr: ThreadedGenerator,
-    query: str,
-    shapeAnalytics: ShapeAnalytics,
-    slackData: Optional[SlackData],
-):
-    try:
-        agent_executor = create_shape_sql_agent(
-            callback_manager=CallbackManager(
-                [
-                    ShapeSQLCallbackHandler(
-                        threadedGntr=threadedGntr,
-                        shapeAnalytics=shapeAnalytics,
-                        slackData=slackData,
-                    ),
-                    StdOutCallbackHandler(),
-                ]
-            ),
-        )
-        agent_executor.run(query)
-    finally:
-        shapeAnalytics.track("agent_thread Completed", {"Query": query})
-        threadedGntr.close()
-
-
-def sqlChain(query: str, username: str) -> ThreadedGenerator:
-    shapeAnalytics = ShapeAnalytics(username)
-    shapeAnalytics.track("sqlChain Invoked", {"Query": query})
-    threadedGntr = ThreadedGenerator()
-    threading.Thread(
-        target=agent_thread, args=(threadedGntr, query, shapeAnalytics, None)
-    ).start()
-    return threadedGntr
-
-
-def slackSqlChain(
-    query: str, username: str, sendMessage: Say, thread_ts: str, channel: str
-) -> ThreadedGenerator:
-    shapeAnalytics = ShapeAnalytics(username)
-    shapeAnalytics.track("slackSqlChain Invoked", {"Query": query})
-    threadedGntr = ThreadedGenerator()
-    slackData = SlackData(
-        username=username, sendMessage=sendMessage, thread_ts=thread_ts, channel=channel
-    )
-    threading.Thread(
-        target=agent_thread, args=(threadedGntr, query, shapeAnalytics, slackData)
-    ).start()
-    slackData.send("ok, this'll take me about 1 or 2 minutes to figure out")
-    return threadedGntr
-
-
-@functions_framework.http
-def runSQLAgent(request):
-    if request.method == "OPTIONS":
-        # Allow GET requests from any origin with the Content-Type
-        # header and cache the preflight response for 3600s.
-        headers = {
-            "Access-Control-Allow-Origin": "*",
-            "Access-Control-Allow-Methods": "GET",
-            "Access-Control-Allow-Headers": "Content-Type",
-            "Access-Control-Max-Age": "3600",
-        }
-        return ("", 204, headers)
-    # Set CORS headers for the main request
-    headers = {"Access-Control-Allow-Origin": "*"}
-    # .get avoids a KeyError when the param is missing, so the 400
-    # branch below is actually reachable.
-    query = request.args.get("query")
-    if query:
-        query = escape(query)
-        return Response(response=sqlChain(query, "HTTP call"), headers=headers)
-    else:
-        return ("Please provide 'query' param", 400, headers)
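The streaming behaviour above hinges on `ThreadedGenerator` (defined in `threaded_generator.py` further down in this diff): `sqlChain` returns the generator immediately, `agent_thread` pushes callback events into it from a worker thread, and Flask iterates it as the response body. The handoff in isolation, with a stand-in worker:

```python
# Isolated sketch of the producer/consumer handoff behind sqlChain:
# a worker thread feeds the queue-backed generator while the caller
# (a plain loop here, a Flask Response in main.py) blocks on it.
import threading
import time

from threaded_generator import ThreadedGenerator


def worker(gen: ThreadedGenerator) -> None:
    try:
        for i in range(3):
            time.sleep(0.1)  # stand-in for agent work
            gen.send(f'{{"step": {i}}}')
    finally:
        gen.close()  # always unblock the consumer, even on error


gen = ThreadedGenerator()
threading.Thread(target=worker, args=(gen,)).start()
for chunk in gen:  # ends cleanly once close() is called
    print(chunk)
```

Closing the generator in a `finally` block mirrors `agent_thread`: if the agent raises, the HTTP response still terminates instead of hanging.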
diff --git a/sql_agent/requirements.txt b/sql_agent/requirements.txt
deleted file mode 100644
index c3c209a1..00000000
--- a/sql_agent/requirements.txt
+++ /dev/null
@@ -1,17 +0,0 @@
-functions-framework==3.2.1
-flask==2.1.0
-google-cloud-error-reporting==1.9.1
-langchain==0.0.150
-openai==0.27.2
-mixpanel==4.10.0
-certifi==2022.12.7
-idna==3.4
-python-decouple==3.8
-slack-bolt==1.17.1
-slack-sdk==3.21.1
-urllib3==1.26.15
-snowflake-sqlalchemy==1.4.7
-snowflake-connector-python==3.0.3
-sqlalchemy-bigquery==1.6.1
-sqlalchemy==1.4.47
-matplotlib==3.7.1
\ No newline at end of file
diff --git a/sql_agent/shape_analytics.py b/sql_agent/shape_analytics.py
deleted file mode 100644
index ad02b0b1..00000000
--- a/sql_agent/shape_analytics.py
+++ /dev/null
@@ -1,19 +0,0 @@
-from mixpanel import Mixpanel
-from decouple import config
-
-
-class ShapeAnalytics:
-    def __init__(self, distinctId: str):
-        shape_analytics_api_key = config("SHAPE_ANALYTICS_API_KEY", default=None)
-        # If the customer has not opted into analytics via an API key, don't initialize Mixpanel
-        self.mixpanel = (
-            None
-            if shape_analytics_api_key is None
-            else Mixpanel(shape_analytics_api_key)
-        )
-        self.distinctId = distinctId
-
-    def track(self, event_name: str, properties=None, meta=None) -> None:
-        if self.mixpanel is None:
-            return
-        self.mixpanel.track(self.distinctId, event_name, properties, meta)
diff --git a/sql_agent/shape_bot_app.py b/sql_agent/shape_bot_app.py
deleted file mode 100644
index a3577507..00000000
--- a/sql_agent/shape_bot_app.py
+++ /dev/null
@@ -1,45 +0,0 @@
-""" Styleguide: https://google.github.io/styleguide/pyguide.html """
-import os
-from slack_bolt import App
-from slack_bolt.adapter.socket_mode import SocketModeHandler
-from slack_bolt.context.say import Say
-from decouple import config
-from main import slackSqlChain
-from typing import Dict, Any
-import logging
-
-# logging.basicConfig(level=logging.DEBUG)
-
-app = App(token=config("SLACK_BOT_TOKEN"))
-
-
-@app.event("app_mention")
-def handle_app_mention(event: Dict[str, Any], say: Say):
-    user = app.client.users_profile_get(user=event["user"])
-
-    thread_ts = event.get("thread_ts", None)
-    slackSqlChain(
-        query=event["text"],
-        username=user["profile"]["real_name"],
-        sendMessage=say,
-        thread_ts=thread_ts if thread_ts else event["ts"],
-        channel=event["channel"],
-    )
-
-
-@app.event("message")
-def handle_direct_message(event: Dict[str, Any], say: Say):
-    if event["channel_type"] == "im":
-        user = app.client.users_profile_get(user=event["user"])
-        thread_ts = event.get("thread_ts", None)
-        slackSqlChain(
-            query=event["text"],
-            username=user["profile"]["real_name"],
-            sendMessage=say,
-            thread_ts=thread_ts if thread_ts else event["ts"],
-            channel=event["channel"],
-        )
-
-
-if __name__ == "__main__":
-    SocketModeHandler(app, config("SLACK_APP_TOKEN")).start()
diff --git a/sql_agent/shape_sql_callback.py b/sql_agent/shape_sql_callback.py
deleted file mode 100644
index fdded55e..00000000
--- a/sql_agent/shape_sql_callback.py
+++ /dev/null
@@ -1,271 +0,0 @@
-from typing import Any, Dict, List, Optional, Union
-from langchain.schema import AgentAction, AgentFinish, LLMResult
-from langchain.agents.agent import AgentExecutor
-from langchain.agents.agent_toolkits.sql.prompt import SQL_SUFFIX
-from langchain.agents.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
-from langchain.agents.mrkl.base import ZeroShotAgent
-from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
-from langchain.callbacks.base import BaseCallbackManager
-from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
-from langchain.chains.llm import LLMChain
-from threaded_generator import ThreadedGenerator
-import json
-from shape_analytics import ShapeAnalytics
-from slack_data import SlackData
-import os
-from decouple import config
-from langchain.llms.openai import OpenAI
-from database_factory import DatabaseFactory
-import uuid
-from slack_sdk.errors import SlackApiError
-from slack_bolt import App
-
-SHAPE_SQL_PREFIX = """You are an agent designed to interact with a SQL database and write matplotlib code.
-Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and write concise and correct matplotlib code to create a chart of the results which saves to an io.BytesIO() buffer called buffer.
-
-Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
-You can order the results by a relevant column to return the most interesting examples in the database.
-Never query for all the columns from a specific table, only ask for the relevant columns given the question.
-You have access to tools for interacting with the database.
-Only use the below tools. Only use the information returned by the below tools to write your matplotlib code.
-You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
-
-DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
-
-If the question does not seem related to the database, immediately return a JSON string with one key-value pair where the key is '{summary}' and the value is 'I don't know'.
-
-Before importing pyplot, you should call 'matplotlib.use('Agg')'.
-
-Make sure to close the figure after saving it to the buffer.
-
-Please return the answer as a JSON string where the first key is '{matplotlib_code}' and the value is the matplotlib code you wrote as a string,
-and the second key is '{summary}' and the value is one English sentence describing the results of the query.
-
-"""
-
-MATPLOTLIB_CODE = "matplotlib_code"
-SUMMARY = "summary"
-CHARTS_FEATURE_ON = False
-
-
-def create_shape_sql_agent(
-    callback_manager: Optional[BaseCallbackManager] = None, **kwargs: Any
-) -> AgentExecutor:
-    """Construct a SQL agent from an LLM and tools."""
-
-    db = DatabaseFactory.create_database()
-
-    os.environ["OPENAI_API_KEY"] = config("OPENAI_API_KEY")
-    llm = OpenAI(temperature=0, model_name="gpt-4")
-    toolkit = SQLDatabaseToolkit(db=db, llm=llm)
-    tools = toolkit.get_tools()
-    prompt = ZeroShotAgent.create_prompt(
-        tools,
-        prefix=SHAPE_SQL_PREFIX.format(
-            dialect=toolkit.dialect,
-            top_k=10,
-            matplotlib_code=MATPLOTLIB_CODE,
-            summary=SUMMARY,
-        ),
-        suffix=SQL_SUFFIX,
-        format_instructions=FORMAT_INSTRUCTIONS,
-    )
-
-    llm_chain = LLMChain(
-        llm=llm,
-        prompt=prompt,
-        callback_manager=callback_manager,
-    )
-    tool_names = [tool.name for tool in tools]
-    agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
-    return AgentExecutor.from_agent_and_tools(
-        agent=agent,
-        tools=tools,
-        callback_manager=callback_manager,
-        verbose=True,
-        max_execution_time=240,
-    )
-
-
-class ShapeSQLCallbackHandler(StreamingStdOutCallbackHandler):
-    """Callback handler for Shape SQL events."""
-
-    def __init__(
-        self,
-        threadedGntr: ThreadedGenerator,
-        shapeAnalytics: ShapeAnalytics,
-        slackData: Optional[SlackData] = None,
-    ):
-        super().__init__()
-        self.threadedGntr = threadedGntr
-        self.shapeAnalytics = shapeAnalytics
-        self.slackData = slackData
-
-    def on_agent_action(
-        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
-    ) -> Any:
-        """Run on agent action."""
-        print("on_agent_action")
-        actionDict = {
-            "on_agent_action": {
-                "tool": action.tool,
-                "tool_input": action.tool_input,
-                "log": action.log,
-            }
-        }
-
-        if self.slackData is not None and action.log.partition("Action:")[0] != "":
-            self.slackData.send(action.log.partition("Action:")[0])
-        if action.tool == "query_sql_db":
-            self.shapeAnalytics.track(
-                "on_agent_action",
-                {
-                    "tool": action.tool,
-                    "tool_input": action.tool_input,
-                },
-            )
-            # Guard the send: the HTTP path runs without SlackData.
-            if self.slackData is not None:
-                self.slackData.send(
-                    f"Here is the SQL statement I will run: \n`{action.tool_input}`"
-                )
-        self.threadedGntr.send(json.dumps(actionDict))
-
-    def on_llm_start(
-        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
-    ) -> Any:
-        """Run when LLM starts running."""
-        pass
-
-    def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
-        """Run on new LLM token. Only available when streaming is enabled."""
-        pass
-
-    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
-        """Run when LLM ends running."""
-        pass
-
-    def on_llm_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> Any:
-        """Run when LLM errors."""
-        pass
-
-    def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
-    ) -> Any:
-        """Run when chain starts running."""
-        print("on_chain_start")
-        startDict = {"on_chain_start": "Entering new chain..."}
-        self.threadedGntr.send(json.dumps(startDict))
-
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
-        """Run when chain ends running."""
-        print("on_chain_end")
-        finishDict = {"on_chain_end": "Finished chain"}
-        self.threadedGntr.send(json.dumps(finishDict))
-
-    def on_chain_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> Any:
-        """Run when chain errors."""
-        pass
-
-    def on_tool_start(
-        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
-    ) -> Any:
-        """Run when tool starts running."""
-        print("on_tool_start")
-
-    def on_tool_end(
-        self,
-        output: str,
-        color: Optional[str] = None,
-        observation_prefix: Optional[str] = None,
-        llm_prefix: Optional[str] = None,
-        **kwargs: Any,
-    ) -> None:
-        """If not the final action, print out observation."""
-        print("on_tool_end")
-        toolEndDict = {
-            "on_tool_end": {
-                "observation_prefix": observation_prefix,
-                "output": output,
-                "llm_prefix": llm_prefix,
-            }
-        }
-        self.threadedGntr.send(json.dumps(toolEndDict))
-        if self.slackData is not None:
-            self.slackData.send(output.partition("Action:")[0])
-
-    def on_tool_error(
-        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
-    ) -> Any:
-        """Run when tool errors."""
-        pass
-
-    def on_text(
-        self,
-        text: str,
-        color: Optional[str] = None,
-        end: str = "",
-        **kwargs: Optional[str],
-    ) -> None:
-        """Run on arbitrary text."""
-        print("on_text")
-        self.threadedGntr.send(text)
-
-    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
-        """Run on agent end."""
-        print("on_agent_finish")
-        finishResponse = {"log": finish.log, "return_values": finish.return_values}
-        finishDict = {"on_agent_finish": finishResponse}
-        self.threadedGntr.send(json.dumps(finishDict))
-
-        # Send a Slack message if applicable
-        if self.slackData is None:
-            return
-
-        # Parse output
-        try:
-            json_string = finish.return_values["output"]
-            json_dict = json.loads(json_string)
-            summary = json_dict[SUMMARY]
-        except Exception as e:
-            print(e)
-            return
-
-        if not CHARTS_FEATURE_ON:
-            self.slackData.send(summary)
-            return
-
-        # Build chart
-        try:
-            code = json_dict[MATPLOTLIB_CODE]
-            local_vars = {}
-            exec(code, {}, local_vars)
-            buffer = local_vars["buffer"]
-
-        except Exception as e:
-            print(e)
-            # Just send the English text summary without the chart then.
-            self.slackData.send(summary)
-            return
-
-        # Upload chart to Slack
-        try:
-            app = App(token=config("SLACK_BOT_TOKEN"))
-            result = app.client.files_upload_v2(
-                channel=self.slackData.channel,
-                thread_ts=self.slackData.thread_ts,
-                file=buffer.getvalue(),
-                filename=f"chart_{uuid.uuid4()}.png",
-                initial_comment=summary,
-            )
-            print(f"SlackResponse is: {result}")
-        except SlackApiError as e:
-            print(f"Error uploading file: {e}")
-            # Just send the English text summary without the chart then.
-            self.slackData.send(summary)
-            return
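The chart path in `on_agent_finish` depends on the model honouring the JSON envelope that `SHAPE_SQL_PREFIX` requests. Below is a sketch of a well-formed envelope and the same `exec`-based rendering, with made-up data; executing model-generated code like this is unsafe without sandboxing, which is presumably why `CHARTS_FEATURE_ON` defaults to `False`:

```python
# Hypothetical agent output satisfying the SHAPE_SQL_PREFIX contract,
# rendered the way on_agent_finish does. The data is made up; exec on
# model-generated code is unsafe without sandboxing.
import json

agent_output = json.dumps({
    "matplotlib_code": (
        "import matplotlib\n"
        "matplotlib.use('Agg')\n"
        "import io\n"
        "import matplotlib.pyplot as plt\n"
        "fig, ax = plt.subplots()\n"
        "ax.bar(['CA', 'TX', 'NY'], [120, 80, 60])\n"
        "buffer = io.BytesIO()\n"
        "fig.savefig(buffer, format='png')\n"
        "plt.close(fig)\n"
    ),
    "summary": "California has the most leads, followed by Texas and New York.",
})

json_dict = json.loads(agent_output)
local_vars = {}
exec(json_dict["matplotlib_code"], {}, local_vars)
png_bytes = local_vars["buffer"].getvalue()
print(json_dict["summary"], f"({len(png_bytes)} PNG bytes)")
```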
diff --git a/sql_agent/slack_data.py b/sql_agent/slack_data.py
deleted file mode 100644
index bdb7485e..00000000
--- a/sql_agent/slack_data.py
+++ /dev/null
@@ -1,12 +0,0 @@
-from slack_bolt.context.say import Say
-
-
-class SlackData:
-    """Context needed to reply into a specific Slack thread."""
-
-    def __init__(self, username: str, sendMessage: Say, thread_ts: str, channel: str):
-        self.username = username
-        self.sendMessage = sendMessage
-        self.thread_ts = thread_ts
-        self.channel = channel
-
-    def send(self, message: str) -> None:
-        self.sendMessage(text=message, thread_ts=self.thread_ts)
diff --git a/sql_agent/threaded_generator.py b/sql_agent/threaded_generator.py
deleted file mode 100644
index c23e9c4f..00000000
--- a/sql_agent/threaded_generator.py
+++ /dev/null
@@ -1,21 +0,0 @@
-import queue
-
-
-class ThreadedGenerator:
-    """Bridges a producer thread and a consuming iterator via a queue."""
-
-    def __init__(self):
-        self.queue = queue.Queue()
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        item = self.queue.get()
-        if item is StopIteration:
-            raise item
-        return item
-
-    def send(self, data):
-        self.queue.put(data)
-
-    def close(self):
-        # The sentinel ends iteration on the consumer side.
-        self.queue.put(StopIteration)
diff --git a/sql_agent/unit_tests.py b/sql_agent/unit_tests.py
deleted file mode 100644
index f33e99ff..00000000
--- a/sql_agent/unit_tests.py
+++ /dev/null
@@ -1,106 +0,0 @@
-from typing import Any, List
-from langchain.prompts import PromptTemplate
-from langchain.llms import OpenAI
-from langchain.evaluation.qa import QAEvalChain
-from shape_sql_callback import create_shape_sql_agent
-from langchain.callbacks.base import CallbackManager
-from langchain.callbacks.stdout import StdOutCallbackHandler
-from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
-from langchain.schema import AgentFinish
-from langchain.agents.agent import AgentExecutor
-
-
-class UnitTestHandler(StreamingStdOutCallbackHandler):
-    def __init__(self, unit_tests: List[dict]):
-        super().__init__()
-        self.unit_tests = unit_tests
-        self.predictions = []
-
-    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
-        """Run on agent end."""
-        print("UnitTestHandler: Agent finished")
-        print(finish.return_values["output"])
-        self.predictions.append({"prediction": finish.return_values["output"]})
-        if len(self.predictions) == len(self.unit_tests):
-            self.evaluate(self.unit_tests, self.predictions)
-
-    def evaluate(self, unit_tests: List[dict], predictions: List[dict]):
-        TEMPLATE = """You are an expert professor specialized in grading answers to questions.
-        You are grading the following question:
-        {query}
-        Here is the real answer:
-        {answer}
-        You are grading the following predicted answer:
-        {result}
-        Is the predicted answer correct? Answer Correct or Incorrect.
-        """
-        PROMPT = PromptTemplate(
-            input_variables=["query", "answer", "result"], template=TEMPLATE
-        )
-
-        evalchain = QAEvalChain.from_llm(
-            llm=OpenAI(model_name="text-davinci-003", temperature=0), prompt=PROMPT
-        )
-
-        unit_tests_scores = evalchain.evaluate(
-            examples=unit_tests,
-            predictions=predictions,
-            question_key="question",
-            answer_key="answer",
-            prediction_key="prediction",
-        )
-
-        correctAnswers = 0
-        for i, unit_test in enumerate(unit_tests):
-            print(f"Test {i+1}:")
-            print("Question: " + unit_test["question"])
-            print("Real Answer: " + unit_test["answer"])
-            print("Predicted Answer: " + predictions[i]["prediction"])
-            print("Status: " + unit_tests_scores[i]["text"])
-            if (
-                unit_tests_scores[i]["text"] == "\nCorrect"
-            ):  # QAEvalChain returns a newline character in the answer
-                correctAnswers += 1
-        print()
-        print(f"{correctAnswers} tests correct out of {len(unit_tests)} tests")
-
-
-def getSqlAgent(unitTestHandler: UnitTestHandler) -> AgentExecutor:
-    # create_shape_sql_agent builds its own LLM and database (via
-    # DatabaseFactory), so only the callback manager is passed here;
-    # any other keyword arguments would be forwarded to ZeroShotAgent.
-    return create_shape_sql_agent(
-        callback_manager=CallbackManager([unitTestHandler, StdOutCallbackHandler()])
-    )
-
-
-def runUnitTests():
-    # Add unit tests here
-    unit_tests = [
-        {"question": "How many tables are there?", "answer": "5 tables"},
-        {"question": "Which state has the most covid cases?", "answer": "California"},
-        {"question": "How many austin bike stations are there?", "answer": "102"},
-    ]
-    unitTestHandler = UnitTestHandler(unit_tests=unit_tests)
-    sqlAgent = getSqlAgent(unitTestHandler=unitTestHandler)
-    for unit_test in unit_tests:
-        question = unit_test["question"]
-        sqlAgent.run(question)
-
-
-if __name__ == "__main__":
-    runUnitTests()