Skip to content

Commit 40cea96

Browse files
committed
chat based
1 parent 3784027 commit 40cea96

File tree

2 files changed

+127
-52
lines changed

2 files changed

+127
-52
lines changed
Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
FROM python:3.11
22

33
COPY app.py app.py
4-
RUN pip install requests jupysql solara openai pandas duckdb duckdb-engine matplotlib
4+
COPY user-logo.png user-logo.png
5+
COPY system-logo.png system-logo.png
6+
COPY assistant-logo.png assistant-logo.png
7+
COPY chat.py chat.py
8+
RUN pip install git+https://github.com/neelasha23/jupysql.git@boxplot
9+
RUN pip install requests solara openai pandas duckdb duckdb-engine matplotlib
510

611

712
ENTRYPOINT ["solara", "run", "app.py", "--host=0.0.0.0", "--port=80"]

examples/docker/chat-with-csv-solara/app.py

Lines changed: 121 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,34 @@
22
from sql.run import run
33
from sql.connection import ConnectionManager
44
from sql.magic import SqlMagic, load_ipython_extension
5-
from sql import _current
65
from IPython.core.interactiveshell import InteractiveShell
76

87

98
import uuid
109
import requests
1110
from functools import partial
1211

13-
import openai
14-
from typing import Optional, cast
15-
16-
import duckdb
17-
import pandas as pd
18-
from matplotlib.figure import Figure
19-
import matplotlib.pyplot as plt
20-
2112
import solara
2213
import solara.lab
2314
from solara.components.file_drop import FileDrop
2415

2516
from sql.plot import boxplot, histogram
2617

18+
from chat import *
19+
20+
from matplotlib import pyplot as plt
2721
plt.switch_backend("agg")
2822

23+
css = """
24+
.main {
25+
width: 100%;
26+
height: 100%;
27+
max-width: 1200px;
28+
margin: auto;
29+
padding: 1em;
30+
}
31+
"""
32+
2933
openai.api_key = "YOUR_API_KEY"
3034

3135
prompt_template = """
@@ -41,16 +45,21 @@
4145
And replace NAME with the column name, do not include the table name
4246
"""
4347

48+
4449
def gen_name():
4550
return str(uuid.uuid4())[:8] + '.csv'
4651

52+
4753
def load_data(name):
4854
run.run_statements(conn, "drop table if exists my_data", sqlmagic)
4955
run.run_statements(conn, f"create table my_data as (select * from '{name}')", sqlmagic)
5056
cols = inspect.get_columns("my_data")
5157
return cols
5258

5359

60+
def delete_data():
61+
run.run_statements(conn, "drop table if exists my_data", sqlmagic)
62+
5463
ip = InteractiveShell()
5564

5665
sqlmagic = SqlMagic(shell=ip)
@@ -67,13 +76,14 @@ def load_data(name):
6776
config=sqlmagic,
6877
)
6978

79+
7080
class State:
71-
user_message = solara.reactive("")
7281
package = solara.reactive("Matplotlib")
7382
initial_prompt = solara.reactive("")
7483
sample_data_loaded = solara.reactive(False)
7584
upload_data = solara.reactive(False)
7685
upload_data_error = solara.reactive("")
86+
results = solara.reactive(20)
7787

7888
@staticmethod
7989
def load_sample():
@@ -89,7 +99,6 @@ def load_sample():
8999
State.initial_prompt.value = prompt_template.format(cols)
90100
else:
91101
solara.Warning("Failed to fetch the data. Check the URL and try again.")
92-
93102

94103
@staticmethod
95104
def load_from_file(file):
@@ -105,41 +114,117 @@ def load_from_file(file):
105114
State.upload_data_error.value = str(e)
106115
return
107116
State.upload_data_error.value = ""
108-
109-
110117

111118
@staticmethod
112119
def reset():
113-
State.user_message.value = ""
120+
delete_data()
121+
State.initial_prompt.value = ""
122+
State.sample_data_loaded.value = False
123+
State.upload_data.value = False
124+
State.upload_data_error.value = ""
114125

115126
@staticmethod
116-
def chat_with_gpt3(prompt):
127+
def chat_with_gpt3(prompts):
117128
response = openai.ChatCompletion.create(
118129
model="gpt-3.5-turbo",
119130
messages=[
120131
{"role": "system", "content": State.initial_prompt.value},
121132
{"role": "user", "content": "Show me the first 5 rows"},
122133
{"role": "assistant", "content": "SELECT * FROM my_data LIMIT 5"},
123-
{"role": "user", "content": prompt}
124-
]
134+
] + [{"role": prompt.role, "content": prompt.content} for prompt in prompts],
135+
temperature=0.1,
136+
stream=True
125137
)
126-
return response['choices'][0]['message']['content']
138+
139+
total = ""
140+
for chunk in response:
141+
part = chunk['choices'][0]['delta'].get("content", "")
142+
total += part
143+
yield total
144+
145+
146+
@solara.component
147+
def Chat() -> None:
148+
solara.Style("""
149+
.chat-input {
150+
max-width: 800px;
151+
})
152+
""")
153+
154+
messages, set_messages = solara.use_state([
155+
Message(
156+
role="assistant",
157+
content=f"Welcome. Please post your queries!",
158+
df=None,
159+
fig=None)
160+
]
161+
)
162+
input, set_input = solara.use_state("")
163+
164+
def ask_chatgpt():
165+
_messages = messages + [Message(role="user", content=input, df=None, fig=None)]
166+
user_input = input
167+
set_input("")
168+
set_messages(_messages)
169+
if State.initial_prompt.value:
170+
final = None
171+
for command in State.chat_with_gpt3([Message(role="user", content=user_input, df=None, fig=None)]):
172+
final = command
173+
174+
if final.startswith("%sqlplot"):
175+
_, name, column = final.split(" ")
176+
fig = Figure()
177+
ax = fig.subplots()
178+
179+
fn_map = {"histogram": partial(histogram, bins=50),
180+
"boxplot": boxplot}
181+
182+
fn = fn_map[name]
183+
ax = fn("my_data", column, ax=ax)
184+
set_messages(_messages + [Message(role="assistant", content="", df=None, fig=fig)])
185+
#[Message(role="assistant", content=final, df=None, fig=None)])
186+
else:
187+
messages_list = ';'.join([msg.content for msg in _messages])
188+
query_result = run.run_statements(conn, final, sqlmagic)
189+
set_messages(_messages + [Message(role="assistant", content="", df=query_result, fig=None)])
190+
# [Message(role="assistant", content=f"Setting in else part: {final} {user_input} "
191+
# f"{messages_list}",
192+
# df=None, fig=None)])
193+
else:
194+
set_messages(_messages + [Message(role="assistant", content="Please load some data first!", df=None, fig=None)])
195+
196+
with solara.VBox():
197+
for message in messages:
198+
ChatBox(message)
199+
200+
with solara.Row(justify="center"):
201+
with solara.HBox(align_items="center", classes=["chat-input"]):
202+
rv.Textarea(v_model=input, on_v_model=set_input, solo=True, hide_details=True, outlined=True,
203+
rows=1,
204+
auto_grow=True)
205+
solara.IconButton("send", on_click=ask_chatgpt)
127206

128207

129208
@solara.component
130209
def Page():
131-
user_message = State.user_message.value
132210
package = State.package.value
133211
initial_prompt = State.initial_prompt.value
134212
sample_data_loaded = State.sample_data_loaded.value
135213
upload_data = State.upload_data.value
136214
upload_data_error = State.upload_data_error.value
215+
results =State.results.value
137216

138217
with solara.AppBarTitle():
139-
solara.Text("Data Visualisation App")
218+
solara.Text("Data Querying and Visualisation App")
140219

141220
with solara.Card(title="About", elevation=6, style="background-color: #f5f5f5;"):
142-
solara.Markdown("""This Solara app is designed for automatic data visualizations""")
221+
solara.Markdown("""This Solara app is designed for chatting with your data. <br> <br>
222+
Examples of queries :
223+
unique column-name values ;
224+
select top 20 rows from table ; <br> <br>
225+
Example of queries that will return a plot :
226+
histogram on column ;
227+
boxplot on column""")
143228

144229
with solara.Sidebar():
145230
with solara.Card("Controls", margin=0, elevation=0):
@@ -150,50 +235,35 @@ def Page():
150235
solara.Button("Clear dataset", color="primary", text=True, outlined=True, on_click=State.reset)
151236
FileDrop(on_file=State.load_from_file, on_total_progress=lambda *args: None,
152237
label="Drag a .csv file here")
238+
if initial_prompt:
239+
solara.InputInt("Number of preview rows", value=State.results, continuous_update=True)
153240

154241
solara.Select(label="Visualisation Library", value=State.package, values=['Matplotlib', 'Plotly', 'Altair'])
155242

156243
solara.Markdown("Hosted in [Ploomber Cloud](https://ploomber.io/)")
157244

158245
if sample_data_loaded:
159246
solara.Info("Sample data is loaded")
160-
sql_output = run.run_statements(conn, "select * from my_data limit 5", sqlmagic)
161-
solara.DataFrame(sql_output)
247+
sql_output = run.run_statements(conn, f"select * from my_data limit {results}", sqlmagic)
248+
solara.DataFrame(sql_output, items_per_page=10)
162249

163250
if upload_data:
164251
solara.Info("Data is successfully uploaded")
165-
sql_output = run.run_statements(conn, "select * from my_data limit 5", sqlmagic)
166-
solara.DataFrame(sql_output)
252+
sql_output = run.run_statements(conn, f"select * from my_data limit {results}", sqlmagic)
253+
solara.DataFrame(sql_output, items_per_page=10)
167254

168255
if upload_data_error:
169256
solara.Error(f"Error uploading data: {upload_data_error}")
170-
171-
if user_message:
172-
if initial_prompt == "" :
173-
solara.Info("You must upload data first")
174-
175-
command = State.chat_with_gpt3(user_message)
176-
177-
if command.startswith("%sqlplot"):
178-
_, name, column = command.split(" ")
179-
fig = Figure()
180-
ax = fig.subplots()
181-
182-
fn_map = {"histogram": partial(histogram, bins=50, ax=ax),
183-
"boxplot": partial(boxplot, ax=ax)}
184-
fn = fn_map[name]
185-
ax = fn("my_data", column)
186-
solara.FigureMatplotlib(fig)
187257

188-
else:
189-
sql_output = run.run_statements(conn, command, sqlmagic)
190-
solara.DataFrame(sql_output)
191-
192-
solara.Markdown(f"Command: {command}")
193-
194-
195-
solara.InputText("Enter your query", value=State.user_message)
196-
258+
if initial_prompt == "":
259+
solara.Info("No data loaded")
260+
261+
solara.Style(css)
262+
with solara.VBox(classes=["main"]):
263+
solara.HTML(tag="h3", style="margin: auto;", unsafe_innerHTML="Chat with your data")
264+
265+
Chat()
266+
197267

198268
@solara.component
199269
def Layout(children):

0 commit comments

Comments
 (0)