-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
181 lines (151 loc) · 5.76 KB
/
app.py
File metadata and controls
181 lines (151 loc) · 5.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import streamlit as st
import pickle
import nltk
import faiss
from nltk.stem import WordNetLemmatizer
import google.generativeai as genai
from tensorflow.keras.models import load_model
from sentence_transformers import SentenceTransformer
import pandas as pd
import os
import datetime
from dotenv import load_dotenv
from download_assets import download_files
# NOTE(review): this runs on every Streamlit rerun; presumably download_files()
# skips assets that are already on disk — confirm in download_assets.py.
download_files()
load_dotenv()
# NLTK keeps tokenizers and corpora in different sub-directories, so each
# resource has to be probed at its own path. Looking everything up under
# "tokenizers/" never finds wordnet / omw-1.4 and re-downloads them each run.
NLTK_RESOURCE_PATHS = {
    "punkt": "tokenizers/punkt",
    "wordnet": "corpora/wordnet",
    "omw-1.4": "corpora/omw-1.4",
}
for resource, resource_path in NLTK_RESOURCE_PATHS.items():
    try:
        nltk.data.find(resource_path)
    except LookupError:
        nltk.download(resource)
lemmatizer = WordNetLemmatizer()
@st.cache_resource
def load_resources():
    """Load the Keras model plus its pickled label encoder and TF-IDF vectorizer.

    Cached with ``st.cache_resource`` so the objects are built once per process.
    Note: this function is not called anywhere else in app.py as shown.

    Returns:
        tuple: ``(model, labels, tfidf)`` loaded from ``chatbot_model.h5``,
        ``label_encoder.pkl`` and ``tfidf_vectorizer.pkl`` in the working
        directory.
    """
    model = load_model("chatbot_model.h5")
    # NOTE(review): pickle.load can execute arbitrary code. Only load these
    # files from a trusted source (they may be fetched by download_files() —
    # confirm). `with` closes the handles the original left open.
    with open("label_encoder.pkl", "rb") as fh:
        labels = pickle.load(fh)
    with open("tfidf_vectorizer.pkl", "rb") as fh:
        tfidf = pickle.load(fh)
    return model, labels, tfidf
def read_qa_csvs(root="data"):
    """Read every ``*.csv`` file under *root* (recursively) into DataFrames.

    Files are returned in ``os.walk`` order. A missing or empty *root* yields
    an empty list.
    """
    frames = []
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in filenames:
            if name.endswith(".csv"):
                frames.append(pd.read_csv(os.path.join(dirpath, name)))
    return frames


all_df = read_qa_csvs("data")
# pd.concat raises ValueError on an empty list, so fall back to an empty frame
# with the expected columns when no CSV files were found.
if all_df:
    data = pd.concat(all_df, ignore_index=True)
else:
    data = pd.DataFrame(columns=["Question", "Answer"])
# Column-wise extraction instead of a per-row iterrows() loop.
# NOTE: qa_answers is rebound further down from qa_answers.pkl by
# load_faiss_components(), which is the list the FAISS index is aligned with.
qa_questions = data["Question"].tolist()
qa_answers = data["Answer"].tolist()
@st.cache_resource
def load_embed_model():
    """Return the SentenceTransformer used to embed user questions.

    ``st.cache_resource`` keeps one instance in memory, so repeated calls
    (including from semantic_search) reuse the same model. It runs on CPU.
    """
    model_name = "all-MiniLM-L6-v2"
    return SentenceTransformer(model_name, device="cpu")
@st.cache_resource
def load_faiss_components():
    """Loads the pre-calculated FAISS index and QA answers list.

    Returns:
        tuple: ``(index, qa_answers_loaded)`` read from ``faiss_index.bin`` and
        ``qa_answers.pkl`` in the working directory.

    On any failure an error is shown in the app and the script run is halted
    with ``st.stop()``.
    """
    index_path = "faiss_index.bin"
    answers_path = "qa_answers.pkl"
    # faiss.read_index reports a missing file as a generic RuntimeError rather
    # than FileNotFoundError. A missing index would otherwise fall through to
    # the generic handler, so check both files first and give the actionable
    # hint in either case.
    for required in (index_path, answers_path):
        if not os.path.exists(required):
            st.error(f"Required file not found: {required}. Please run 'python create_index.py' first.")
            st.stop()
    try:
        index = faiss.read_index(index_path)
        # NOTE(review): pickle.load can execute arbitrary code; only load a
        # qa_answers.pkl produced by a trusted create_index.py run.
        with open(answers_path, "rb") as f:
            qa_answers_loaded = pickle.load(f)
        return index, qa_answers_loaded
    except FileNotFoundError as e:
        # Still reachable if a file disappears between the check and the read.
        st.error(f"Required file not found: {e.filename}. Please run 'python create_index.py' first.")
        st.stop()
    except Exception as e:
        st.error(f"Error loading FAISS components: {e}")
        st.stop()
# Resolve the cached loaders at startup. If load_faiss_components() hits a
# problem it calls st.stop(), which halts the run before the embedding model
# is requested.
faiss_parts = load_faiss_components()
index, qa_answers = faiss_parts
embed_model = load_embed_model()
def semantic_search(user_question, top_k=3):
    """Return up to *top_k* stored answers nearest to *user_question*.

    The question is embedded with the cached SentenceTransformer and
    L2-normalised before querying the module-level FAISS ``index``.
    Presumably the index was built over normalised vectors, so this acts as
    cosine similarity — TODO confirm in create_index.py. Hit ids are mapped
    to ``qa_answers`` by position.

    Returns:
        list: between 0 and *top_k* answers, best match first.
    """
    query = load_embed_model().encode([user_question]).astype("float32")
    faiss.normalize_L2(query)
    _distances, indices = index.search(query, k=top_k)
    # FAISS pads results with -1 when the index holds fewer than top_k
    # vectors. qa_answers[-1] would silently return the last answer, so drop
    # any id that is not a valid position.
    return [qa_answers[i] for i in indices[0] if 0 <= i < len(qa_answers)]
# Prefer Streamlit secrets, but fall back to the environment (populated from
# .env by load_dotenv() above) so the missing-key branch below is reachable.
# Indexing st.secrets directly raises before that check can run.
try:
    api_key = st.secrets.get("GEMINI_API_KEY")
except Exception:  # st.secrets raises when no secrets.toml is configured at all
    api_key = None
api_key = api_key or os.getenv("GEMINI_API_KEY")
# The former `st.write("API key loaded: ...")` debug line was removed. It
# showed the key status to every visitor, and it was a Streamlit command
# issued before st.set_page_config.
if not api_key:
    st.error("Gemini API key not found! Check .env file.")
    st.stop()
genai.configure(api_key=api_key)
model = genai.GenerativeModel("gemini-2.5-flash")
def summarize_with_gemini(answer_text):
    """Ask the module-level Gemini ``model`` for a concise summary of *answer_text*.

    Best-effort: any failure, including request errors and a missing ``.text``,
    is returned as a "Summarization failed: ..." string rather than raised.
    """
    prompt = "Summarize the following medical explanation clearly and concisely:\n\n{}".format(answer_text)
    try:
        reply = model.generate_content(prompt)
        summary = reply.text.strip()
    except Exception as exc:
        return f"Summarization failed: {exc}"
    return summary
def truncate_answer(answer, max_sentences=2):
    """Shorten *answer* to its first *max_sentences* ". "-separated pieces.

    The text is returned unchanged when it already has at most
    *max_sentences* pieces. Otherwise the kept pieces are re-joined with ". "
    and an ellipsis ("...") is appended.
    """
    pieces = answer.split(". ")
    if len(pieces) > max_sentences:
        return ". ".join(pieces[:max_sentences]) + "..."
    return answer
# --- Define the Chat Processing Function ---
def process_user_input():
    """Handle a submitted question (``on_change`` callback of the text box).

    Reads the text from ``st.session_state.user_question_key`` and answers it
    with one of the built-in commands or with the semantic search over the QA
    data. It then appends the exchange to ``st.session_state.history`` and
    clears the input box.

    Built-in commands (case-insensitive, surrounding whitespace ignored):
        ``exit`` / ``quit`` / ``bye``  -- farewell message
        ``search <terms>``             -- Google search link for <terms>
        ``time``                       -- current local time
    """
    from urllib.parse import quote_plus

    # 1. Get the current user input; whitespace-only input counts as empty.
    user_input = (st.session_state.user_question_key or "").strip()
    # 2. Skip if input is empty
    if not user_input:
        st.session_state.user_question_key = ""
        return
    command = user_input.lower()
    # Only knowledge-base answers are worth an LLM call. Canned replies are
    # shown verbatim, so Gemini cannot reword them (e.g. mangle the search
    # URL) and no API call is spent on them.
    needs_summary = False
    # 3. Process the question
    if command in ("exit", "quit", "bye"):
        full_response = "Goodbye! Have a great day!"
    elif command == "search" or command.startswith("search "):
        # Percent-encode the terms. A raw space would cut off the link when
        # the reply is rendered as Markdown.
        terms = quote_plus(user_input[len("search"):].strip())
        full_response = f"You can search this on Google: https://www.google.com/search?q={terms}"
    elif command == "time":
        full_response = f"The current time is {datetime.datetime.now().strftime('%H:%M:%S')}."
    else:
        top_answers = semantic_search(user_input, top_k=1)
        if top_answers:
            full_response = top_answers[0]
            needs_summary = True
        else:
            full_response = "I'm not sure how to respond. Can you rephrase?"
    # 4. Summarize (knowledge-base answers only) and append to history
    summary = summarize_with_gemini(full_response) if needs_summary else full_response
    st.session_state.history.append(("You", user_input))
    st.session_state.history.append(("Bot", {"summary": summary, "full": full_response}))
    # 5. Clear the input box after submission
    st.session_state.user_question_key = ""
st.set_page_config(page_title="Medical Q&A Chatbot", page_icon="💬")
st.title("💬 Medical Q&A Chatbot")
# Seed per-session state on the first run only: the chat history list and the
# value bound to the question box.
for state_key, initial in (("history", []), ("user_question_key", "")):
    if state_key not in st.session_state:
        st.session_state[state_key] = initial
# The question box is bound to session state via `key`; process_user_input
# runs as its on_change callback whenever the user submits text.
user_input = st.text_input(
    "Ask me a medical question:",
    on_change=process_user_input,
    key="user_question_key",
)
# Draw the conversation: user turns as one Markdown line, bot turns as the
# summary plus a collapsible full answer.
for who, entry in st.session_state.history:
    if who == "You":
        st.markdown(f"**🧑 You:** {entry}")
        continue
    # Bot entries are {"summary": ..., "full": ...} dicts.
    st.markdown(f"**🤖 Summary:** {entry['summary']}")
    with st.expander("📖 Full Answer"):
        st.write(entry["full"])
def _clear_chat_history():
    """on_click callback: empty the stored conversation before the rerun draws it."""
    st.session_state.history = []


def _format_chat_history(history):
    """Render ``(speaker, message)`` pairs as a plain-text transcript.

    Bot entries are dicts with "summary" and "full" keys. Interpolating them
    directly would write a Python dict repr into the file, so each part gets
    its own labelled line instead. Every line ends with a newline.
    """
    lines = []
    for speaker, msg in history:
        if isinstance(msg, dict):
            lines.append(f"{speaker} (summary): {msg.get('summary', '')}")
            lines.append(f"{speaker} (full answer): {msg.get('full', '')}")
        else:
            lines.append(f"{speaker}: {msg}")
    return "\n".join(lines) + "\n"


col1, col2 = st.columns(2)
with col1:
    # An on_click callback runs before the script reruns, so the history
    # loop already sees the emptied list. A plain `if st.button(...)` would
    # clear it only after the old messages had been drawn.
    st.button("🧹 Clear Chat", on_click=_clear_chat_history)
with col2:
    if st.button("💾 Save Chat"):
        if st.session_state.history:
            chat_text = _format_chat_history(st.session_state.history)
            filename = f"chat_history_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"
            st.download_button("📥 Download Chat History", chat_text, file_name=filename)
        else:
            st.warning("Chat is empty!")