webui.py
##########################
# Chatbot User Interface #
##########################
# Python Standard Library Imports
import io # Input/Output operations
import random # Random number generation
import uuid # Universally Unique Identifier generation
from pathlib import Path # Handling file system paths
from typing import NamedTuple # Type hinting for named tuples
# Third-Party Library Imports
import streamlit as st # UI Framework for creating web applications
from langchain_core.language_models.llms import LLM
from langchain_community.llms import LlamaCpp
# Local Imports
from webui_config import UiConfig # Configuration settings for the web UI
from llm_connector import llm_stream_result, LlmGenerationParameters, craft_prompt, craft_result_with_prompt
from document_rag_processor import topk_documents, RagParameters


def main_ui_logic(config: UiConfig, llm_instance: LLM) -> None:
    """Render the chat interface and handle one Streamlit rerun."""
    st.title("🎓 LLM Inference Web UI")

    ### Environment preparation.
    document_folder = Path(config.document_folder)
    # TODO: Logger: display warning.
    document_folder.mkdir(exist_ok=True)

    ### States
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "history" not in st.session_state:
        st.session_state.history = []
    if "documents" not in st.session_state:
        st.session_state.documents = []
    if "session_id" not in st.session_state:
        # Generate a user session identifier.
        st.session_state.session_id = uuid.uuid4().hex

    ### Components
    with st.sidebar:
        st.markdown("# General Settings")
        with st.expander("Parameter descriptions"):
            st.markdown("### LLM Generation Parameters")
            st.markdown("Top K: keep only the K most probable tokens")
            st.markdown("Top P: sample from the tokens whose cumulative probability reaches P")
            st.markdown("Temperature: randomness of the generated output")
            st.markdown("Repetition Penalty: penalty applied to repeated tokens")
        st.markdown("### LLM Generation Parameters")
        model_topk = st.slider("Top K", 0, 200, 10, key="model_topk")
        model_topp = st.slider("Top P", 0.0, 1.0, 0.9, key="model_topp")
        model_temperature = st.slider("Temperature", 0.0, 1.0, 0.75, key="model_temperature")
        model_repetition_penalty = st.slider("Repetition Penalty", 0.0, 2.0, 1.00, key="model_repetition_penalty")

    # Display chat messages from history on app rerun.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # React to user input.
    if user_input := st.chat_input("How can I help you today?"):
        # TODO: User model selection.
        llm_model_conf = config.llm_models
        embedding_conf = config.embedding_model
        llm_param = LlmGenerationParameters.new_generation_parameter(
            top_k=model_topk,
            top_p=model_topp,
            temperature=model_temperature,
            repetition_penalty=model_repetition_penalty,
            max_new_tokens=8192,
        )

        # Display the user message in a chat message container.
        st.chat_message("user").markdown(user_input)
        # Add the user message to the chat history.
        st.session_state.messages.append({"role": "user", "content": user_input})

        # Display the assistant response in a chat message container.
        with st.chat_message("assistant"):
            message_placeholder = st.empty()
            full_response = ""

            ## RAG (retrieval is not wired in yet; no documents are attached to the prompt).
            rag_docs = []

            # Prompt crafting.
            prompt = craft_prompt(user_input, rag_content=[], keep_placeholder=False)
            # TODO: Append history to prompt.
            #prompt = "".join(st.session_state.history) + prompt

            # Stream the response, simulating bot typing.
            cursor = "❖"
            for response in llm_stream_result(llm_instance, prompt, llm_param):  # type: ignore
                full_response += (response or "")
                message_placeholder.markdown(full_response + cursor)

            # Once streaming is complete, display the full bot response.
            with message_placeholder.container():
                st.markdown(full_response)
                with st.expander("Raw Output"):
                    st.text_area("Raw Model Output", full_response)

        full_response_with_prompt = craft_result_with_prompt(user_input, full_response)
        # Add the assistant response to the chat history.
        st.session_state.history.append(full_response_with_prompt)
        st.session_state.messages.append({"role": "assistant", "content": full_response})


if __name__ == "__main__":
    # Load config.
    with open("config.yaml", "r", encoding="utf-8") as f:
        config = UiConfig.load_config_from_file(f)
    # Load LLM model.
    if config.llm_models.provider == "llama-cpp":
        llm_instance = LlamaCpp(
            model_path=config.llm_models.model_path,
            verbose=False,
            max_tokens=8192,
            n_ctx=8192,
        )
    else:
        raise NotImplementedError("Might implement sometime lol.")
    main_ui_logic(config=config, llm_instance=llm_instance)
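
# ---------------------------------------------------------------------------
# Usage note: as a Streamlit app, this script is typically launched with
#   streamlit run webui.py
# rather than being executed directly with the Python interpreter.
#
# Below is a minimal sketch of what config.yaml might contain. The real schema
# is defined by UiConfig in webui_config (not shown here), so every key and
# value is an assumption inferred only from the attributes this file accesses
# (document_folder, llm_models.provider, llm_models.model_path,
# embedding_model):
#
#   document_folder: ./documents
#   llm_models:
#     provider: llama-cpp
#     model_path: ./models/model.gguf
#   embedding_model:
#     model_path: ./models/embedding.gguf
# ---------------------------------------------------------------------------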