forked from carolinefrasca/llamaindex-build-2023
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathstreamlit_app.py
135 lines (114 loc) · 6.26 KB
/
streamlit_app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import streamlit as st
from llama_index.llms import OpenAI
import openai
import logging
import sys
import os
import llama_hub
from streamlit_pills import pills
from llama_index.tools.query_engine import QueryEngineTool
from llama_index.objects import ObjectIndex, SimpleToolNodeMapping
# from llama_index.query_engine import ToolRetrieverRouterQueryEngine
from llama_index import (
VectorStoreIndex,
SummaryIndex,
ServiceContext,
StorageContext,
download_loader
)
from llama_index.query_engine.router_query_engine import RouterQueryEngine
from llama_index.selectors.llm_selectors import (
LLMSingleSelector,
LLMMultiSelector,
)
from llama_index.selectors.pydantic_selectors import (
PydanticMultiSelector,
PydanticSingleSelector,
)
# --- Page configuration and header ---
# NOTE: st.set_page_config must be the first Streamlit call in the script.
st.set_page_config(page_title="Chat with Snowflake's Wikipedia page, powered by LlamaIndex", page_icon="🦙", layout="centered", initial_sidebar_state="auto", menu_items=None)
st.title("Chat with Snowflake's Wikipedia page, powered by LlamaIndex 💬🦙")
st.info("Because this chatbot is powered by **LlamaIndex's [router query engine](https://gpt-index.readthedocs.io/en/latest/examples/query_engine/RouterQueryEngine.html)**, it can answer both **summarization questions** and **context-specific questions** based on the contents of [Snowflake's Wikipedia page](https://en.wikipedia.org/wiki/Snowflake_Inc.).", icon="ℹ️")
# OpenAI key comes from Streamlit secrets (.streamlit/secrets.toml, key "openai_key").
openai.api_key = st.secrets.openai_key
# Mirror INFO-level logging (including llama_index internals) to stdout.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))
@st.cache_resource
def load_index_data():
    """Download Snowflake's Wikipedia article and build two indexes over it.

    Returns:
        A ``(SummaryIndex, VectorStoreIndex)`` pair built from the same set
        of parsed nodes, sharing one in-memory docstore. Cached across
        Streamlit reruns via ``st.cache_resource``.
    """
    # Fetch the article text through the llama_hub Wikipedia loader.
    reader_cls = download_loader("WikipediaReader", custom_path="local_dir")
    documents = reader_cls().load_data(pages=['Snowflake Inc.'])

    # Split the article into 1024-token chunks.
    service_ctx = ServiceContext.from_defaults(chunk_size=1024)
    parsed_nodes = service_ctx.node_parser.get_nodes_from_documents(documents)

    # Register the nodes once in a shared in-memory storage context so both
    # indexes operate over the same docstore.
    storage = StorageContext.from_defaults()
    storage.docstore.add_documents(parsed_nodes)

    return (
        SummaryIndex(parsed_nodes, storage_context=storage),
        VectorStoreIndex(parsed_nodes, storage_context=storage),
    )
summary_index, vector_index = load_index_data()

# Build each query engine once per session and keep it in session state so
# Streamlit reruns reuse it. (Membership test directly on st.session_state;
# calling .keys() for an `in` check is redundant.)
if "list_query_engine" not in st.session_state:  # Initialize the query engine
    st.session_state.list_query_engine = summary_index.as_query_engine(
        response_mode="tree_summarize",
        use_async=True,
    )
if "vector_query_engine" not in st.session_state:
    st.session_state.vector_query_engine = vector_index.as_query_engine()

# Wrap each engine as a tool; the router selects between them based purely
# on these natural-language descriptions.
list_tool = QueryEngineTool.from_defaults(
    query_engine=st.session_state.list_query_engine,
    description=(
        "Useful for questions summarizing Snowflake's Wikipedia page"
    ),
)
vector_tool = QueryEngineTool.from_defaults(
    query_engine=st.session_state.vector_query_engine,
    description=(
        "Useful for retrieving specific information about Snowflake"
    ),
)

if "router_query_engine" not in st.session_state:  # Initialize the query engine
    st.session_state.router_query_engine = RouterQueryEngine(
        selector=PydanticSingleSelector.from_defaults(),
        query_engine_tools=[list_tool, vector_tool],
    )
# Starter questions rendered as clickable pills; `selected` is the clicked
# label or None.
selected = pills(
    "Choose a question to get started or write your own below.",
    [
        "What is Snowflake?",
        "What company did Snowflake announce they would acquire in October 2023?",
        "What company did Snowflake acquire in March 2022?",
        "When did Snowflake IPO?",
    ],
    clearable=True,
    index=None,
)

# Seed the chat history once per session (membership test directly on
# st.session_state; .keys() is redundant for `in` checks).
if "messages" not in st.session_state:  # Initialize the chat messages history
    st.session_state.messages = [
        {"role": "assistant", "content": "Ask me a question about Snowflake!"}
    ]
def add_to_message_history(role, content):
    """Append one chat turn to the session history.

    Args:
        role: Chat role, e.g. "user" or "assistant".
        content: Message payload; coerced to ``str`` before storing.
    """
    st.session_state.messages.append({"role": role, "content": str(content)})
# Replay the conversation so far.
for message in st.session_state.messages:  # Display the prior chat messages
    with st.chat_message(message["role"]):
        st.write(message["content"])

# Human-readable engine names; order must match the query_engine_tools order
# passed to RouterQueryEngine above ([list_tool, vector_tool]).
query_engines = ["list query engine", "vector query engine"]

if selected:
    # A starter pill was clicked: render it as the user's turn, answer it,
    # and record both turns in history.
    with st.chat_message("user"):
        st.write(selected)
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            response = st.session_state.router_query_engine.query(selected)
            st.write(str(response))
            add_to_message_history("user", selected)
            add_to_message_history("assistant", response)
            # Explain which engine the router chose and why.
            selection = response.metadata["selector_result"].dict()["selections"][0]
            query_engine_used = query_engines[selection["index"]]
            reason = selection["reason"]
            # Lower-case the reason's first letter so it reads mid-sentence,
            # unless it begins a proper noun ("Snow..." for Snowflake).
            if not reason.startswith("Snow"):
                reason = reason[0:1].lower() + reason[1:]
            explanation = "Used the **" + query_engine_used + "** to answer this question because " + reason
            st.success(explanation, icon="✅")
if prompt := st.chat_input("Your question"):  # Prompt for user input and save to chat history
    # Record the turn via the shared helper (the original appended the raw
    # dict directly) and render it immediately — without the explicit render
    # the typed question, unlike the pill branch, would not appear on screen
    # until the next rerun.
    add_to_message_history("user", prompt)
    with st.chat_message("user"):
        st.write(prompt)

# If last message is not from assistant, generate a new response
if st.session_state.messages[-1]["role"] != "assistant":
    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            response = st.session_state.router_query_engine.query(prompt)
            st.write(str(response))
            add_to_message_history("assistant", response)
            # Explain which engine the router chose and why.
            selection = response.metadata["selector_result"].dict()["selections"][0]
            query_engine_used = query_engines[selection["index"]]
            reason = selection["reason"]
            # Lower-case the reason's first letter so it reads mid-sentence,
            # unless it begins a proper noun ("Snow..." for Snowflake).
            if not reason.startswith("Snow"):
                reason = reason[0:1].lower() + reason[1:]
            explanation = "Used the **" + query_engine_used + "** to answer this question because " + reason
            st.success(explanation, icon="✅")