"""view_streamlit.py: Streamlit front end with Train, Test, View and Draw pages."""
import streamlit as st
import os
from streamlit_option_menu import option_menu
import numpy as np
import pandas as pd
from raceplotly.plots import barplot
from collections import deque
import json
import train
import test
import tools.analysis_tools.analyze_logs as alog
import tools.analysis_tools.confusion_matrix as tac
import os
from streamlit_echarts import st_echarts
def is_exit(path):
    """Ensure that the directory ``path`` exists, creating it when missing.

    A status line is printed in both cases (already present / just created).

    Args:
        path: Directory to check; missing parent directories are created too.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
    """
    if os.path.isdir(path):
        print(f"目录已存在:{path}")
        return
    # exist_ok=True closes the check-then-create race: if something else
    # creates the directory between the check above and this call, the call
    # still succeeds instead of raising.
    os.makedirs(path, exist_ok=True)
    print(f"已创建目录:{path}")
def train_model(params):
    """Run a training job and report its progress on the Streamlit page.

    Args:
        params: Training settings; echoed to the page and passed unchanged
            to ``train.run_model``.
    """
    st.write("开始训练模型...")
    st.write("训练参数:", params)
    # The actual training is delegated to the project's ``train`` module;
    # its return value is not used here.
    train.run_model(params)
    st.write('Train success!!!!⭐')
def creat_upfile_config():
    """Show a jpg/png uploader and save the uploaded file to ``./uploaded_files/``.

    Returns:
        tuple[str, str]: ``(saved_path, filename)``. Both are empty strings
        when nothing has been uploaded (or the upload has no usable name).
    """
    config_file = st.file_uploader("Upload a configuration file", type=["jpg", "png"])
    if config_file is None:
        return '', ''

    print(config_file)
    # The file name is supplied by the client. Keep only its last path
    # component so that a crafted name such as "../../x.png" cannot make the
    # write below land outside ./uploaded_files/.
    filename = os.path.basename(config_file.name.replace("\\", "/"))
    if not filename:
        return '', ''

    config_path = f"./uploaded_files/{filename}"
    with open(config_path, "wb") as f:
        f.write(config_file.getvalue())
    return config_path, filename
def list_files_in_directory(directory):
    """Return the names of the regular files directly inside ``directory``.

    Sub-directories are skipped. The names are sorted because the operating
    system returns directory entries in arbitrary order, which would make the
    option order of the select boxes built from this list unpredictable.

    Args:
        directory: Path of the directory to scan.

    Returns:
        list[str]: File names (not full paths) in ascending order.

    Raises:
        OSError: If ``directory`` does not exist or cannot be read.
    """
    with os.scandir(directory) as entries:
        return sorted(entry.name for entry in entries if entry.is_file())
def creat_page_train():
    """Render the training page.

    A sidebar form collects the number of epochs, the learning rate, the
    batch size and a config file from ``./config_model/``. One submit button
    echoes the collected parameters; the other starts training through
    ``train_model``.
    """
    sbdir = st.sidebar.form("训练参数")
    sbdir.title("训练参数设置")

    epochs = sbdir.slider("训练轮数", min_value=1, max_value=10, value=5)
    # Without an explicit step a float slider moves in 0.01 increments and
    # shows two decimals, which is too coarse for a 0.001-0.02 range.
    learning_rate = sbdir.slider(
        "学习率",
        min_value=0.001,
        max_value=0.02,
        value=0.02,
        step=0.001,
        format="%.3f",
    )
    batch_size = sbdir.selectbox("批次大小", options=[1, 2, 4, 8, 16], index=2)

    # selectbox returns None when the directory holds no files, so the path
    # must not be built by unconditional string concatenation.
    config_name = sbdir.selectbox("请选择一个配置文件", list_files_in_directory('./config_model/'))
    config_path = None if config_name is None else './config_model/' + config_name

    train_params = {
        "epochs": epochs,
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "config": config_path,
    }

    # Both buttons are always rendered so the form never lacks a submit button.
    show_clicked = sbdir.form_submit_button("测试参数")
    train_clicked = sbdir.form_submit_button("开始训练")

    if show_clicked:
        st.write(train_params)
    elif train_clicked:
        if config_path is None:
            st.warning("请先在 ./config_model/ 目录中放入配置文件")
            return
        train_model(train_params)
def creat_page_test():
    """Render the test page.

    A sidebar form selects a config file, a weight file and a pkl file, plus
    an IoU value; an image is uploaded through ``creat_upfile_config``. Three
    submit buttons are offered: echo the parameters, run ``test.predict`` and
    show the input/result images with an ECharts chart, or plot the figure
    returned by ``tac.matrix``.
    """
    sbdir = st.sidebar.form("训练参数")
    sbdir.title("训练参数")

    config_name = sbdir.selectbox("请选择一个配置文件", list_files_in_directory('./config_model/'))
    epoch_name = sbdir.selectbox("请选择一个权重文件", list_files_in_directory('./work_dirs/'))
    img_path, filename = creat_upfile_config()
    iou = sbdir.slider("iou", min_value=0.0, max_value=1.0, value=0.2)
    pkl_name = sbdir.selectbox("请选择一个pkl文件", list_files_in_directory('./work_dirs/'))

    # selectbox returns None when a directory holds no files, so the paths
    # must not be built by unconditional string concatenation.
    config_path = None if config_name is None else './config_model/' + config_name
    epoch_path = None if epoch_name is None else './work_dirs/' + epoch_name
    pkl_path = None if pkl_name is None else './work_dirs/' + pkl_name

    train_params = {
        "epochs": epoch_path,
        "config_path": config_path,
        "img_path": img_path,
        "filename": filename,
        "iou": iou,
    }

    # All buttons are always rendered so the form never lacks a submit button.
    show_clicked = sbdir.form_submit_button("测试参数")
    predict_clicked = sbdir.form_submit_button("统计")
    matrix_clicked = sbdir.form_submit_button("混淆矩阵")

    if show_clicked:
        st.write(train_params)
    elif predict_clicked:
        if config_path is None or epoch_path is None:
            st.warning("请先选择配置文件和权重文件")
            return
        if not img_path:
            # creat_upfile_config returns '' until an image has been uploaded.
            st.warning("请先上传一张图片")
            return
        res_path, options = test.predict(train_params, iou)
        # Input image on the left, prediction result on the right.
        col1, col2 = st.columns(2)
        col1.image(img_path, use_column_width=True)
        col2.image(res_path, use_column_width=True)
        st_echarts(options, height=400)
    elif matrix_clicked:
        if None in (config_path, epoch_path, pkl_path):
            st.warning("请先选择配置文件、权重文件和pkl文件")
            return
        fig = tac.matrix({**train_params, "pkl_path": pkl_path})
        # NOTE(review): this option may not exist in newer Streamlit
        # releases; confirm against the installed version.
        st.set_option('deprecation.showPyplotGlobalUse', False)
        st.pyplot(fig)
def _split_log_records(text):
    """Parse JSON-lines text into (records whose first key is 'lr', all other records)."""
    lr_rows, other_rows = [], []
    for line in text.split('\n'):
        if not line.strip():
            continue
        record = json.loads(line)
        # An empty object (or a non-object line) has no first key to
        # classify by; skip it instead of failing on the lookup.
        if not isinstance(record, dict) or not record:
            continue
        first_key = next(iter(record))
        if first_key == 'lr':
            lr_rows.append(record)
        else:
            other_rows.append(record)
    return lr_rows, other_rows


def _render_raceplot_form(frame, form_key):
    """Render one chart-settings form for ``frame`` and draw its bar-chart race on submit."""
    st.write('---')
    st.markdown('<p class="font">设置参数...</p>', unsafe_allow_html=True)

    # '-' is the "nothing chosen yet" placeholder at the top of both selectors.
    column_list = ['-'] + list(frame)
    # Every row gets the same item label, so the race plot has a single bar.
    item_column = '数值'
    frame.insert(0, item_column, '数值')

    with st.form(key=form_key):
        text_style = '<p style="font-family:sans-serif; color:red; font-size: 15px;">***下面两列是必填项***</p>'
        st.markdown(text_style, unsafe_allow_html=True)
        col2, col3 = st.columns([1, 1])
        with col2:
            value_column = st.selectbox('应变量:', column_list, index=0, help='希望观察哪个数据的变化')
        with col3:
            time_column = st.selectbox('自变量:', column_list, index=0, help='由哪个数据引起的变化,即按照什么序列变化')

        text_style = '<p style="font-family:sans-serif; color:blue; font-size: 15px;">***微调选项(可选)***</p>'
        st.markdown(text_style, unsafe_allow_html=True)
        col4, col5, col6 = st.columns([1, 1, 1])
        with col4:
            direction = st.selectbox('选择数据变化方向:', ['-', '横向变化', '纵向变化'], index=0, help='默认横向变化')
            # Horizontal is the default when the placeholder is left selected.
            orientation = 'vertical' if direction == '纵向变化' else 'horizontal'
        with col5:
            item_label = st.text_input('添加纵轴标签:')
        with col6:
            value_label = st.text_input('添加横轴标签')

        col10, col11, col12 = st.columns([1, 1, 1])
        with col10:
            speed = st.slider('动画速度', 10, 500, 100, step=10)
            # A higher speed means a shorter duration per frame.
            frame_duration = 500 - speed
        with col11:
            chart_width = st.slider('表格宽度', 500, 1000, 500, step=20)
        with col12:
            chart_height = st.slider('表格高度', 100, 1000, 300, step=20)
        submitted = st.form_submit_button('提交')

    st.write('---')
    if not submitted:
        return
    if value_column == '-' or time_column == '-':
        st.warning("请完成两个必填项")
        return

    st.markdown('<p class="font">生成图形中... 完成!</p>', unsafe_allow_html=True)
    # NOTE(review): these two converted columns are added to the frame, but
    # barplot below is given the originally selected column names, so the
    # converted copies are not what gets plotted -- confirm the intent.
    frame['time_column'] = pd.to_datetime(frame[time_column])
    frame['value_column'] = frame[value_column].astype(float)
    raceplot = barplot(frame, item_column=item_column, value_column=value_column, time_column=time_column)
    fig = raceplot.plot(
        item_label=item_label,
        value_label=value_label,
        time_label=time_column + 's:',
        frame_duration=frame_duration,
        orientation=orientation,
    )
    fig.update_layout(
        autosize=False,
        width=chart_width,
        height=chart_height,
        paper_bgcolor="lightgray",
    )
    st.plotly_chart(fig, use_container_width=True)


def creat_page_view():
    """Render the log viewer page.

    The user uploads a JSON-lines file. Records whose first key is ``'lr'``
    go into one table and all other records into a second one. Both tables
    are displayed, and each gets its own settings form that draws an animated
    bar-chart race when submitted.

    Raises:
        json.JSONDecodeError: If a non-blank line of the upload is not valid JSON.
    """
    st.markdown(
        " <style> .font {\n"
        "font-size:25px ; font-family: 'Cooper Black'; color: #FF9633;}\n"
        "</style> ",
        unsafe_allow_html=True,
    )
    st.markdown('<p class="font">上传文件...</p>', unsafe_allow_html=True)
    uploaded_file = st.file_uploader('', type=["json"])
    if uploaded_file is None:
        return

    file_contents = uploaded_file.read().decode('utf-8')
    data_list, mmap_list = _split_log_records(file_contents)
    df = pd.DataFrame(data_list)
    mmap_df = pd.DataFrame(mmap_list)
    st.write(df)
    st.write(mmap_df)

    # One settings form + chart per table; the form keys must stay distinct.
    _render_raceplot_form(df, 'columns_in_form')
    _render_raceplot_form(mmap_df, 'columns_in_form2')
def creat_page_draw():
    """Render the custom plotting page.

    A sidebar form selects a JSON-lines file from ``./json/``. The keys of the
    record on its first line are offered as a multi-select; on submit the file
    path and the chosen keys are passed to ``alog.draw_view`` and the image it
    returns is displayed.
    """
    sbdir = st.sidebar.form("diy参数画图")
    sbdir.title("diy参数画图")

    # selectbox returns None when ./json/ holds no files.
    selected = sbdir.selectbox("请选择一个配置文件", list_files_in_directory('./json/'))
    json_path = None if selected is None else './json/' + selected

    available_keys = []
    if json_path is not None:
        with open(json_path, 'r', encoding='utf-8') as f:
            first_line = f.readline().strip()
        # A blank first line has nothing to parse; offer no keys in that case.
        if first_line:
            data = json.loads(first_line)
            if isinstance(data, dict):
                available_keys = list(data)

    keys = sbdir.multiselect("选择纵坐标(可多选)", available_keys)

    if sbdir.form_submit_button("开始画图"):
        if json_path is None:
            st.warning("请先在 ./json/ 目录中放入日志文件")
            return
        # draw_view is given a list of log paths; only one file is plotted here.
        img_path = alog.draw_view([json_path], keys)
        print(img_path)
        st.image(img_path)
def main():
    """Create the working directories, draw the sidebar menu and show the chosen page."""
    # Make sure every directory the pages read from or write to is present.
    for folder in ('./config_model/', './work_dirs/', './json/', './uploaded_files/'):
        is_exit(folder)

    menu_styles = {
        "container": {"padding": "5!important", "background-color": "#fafafa"},
        "icon": {"color": "orange", "font-size": "25px"},
        "nav-link": {"font-size": "16px", "text-align": "left", "margin":"0px", "--hover-color": "#eee"},
        "nav-link-selected": {"background-color": "#02ab21"},
    }
    with st.sidebar:
        choose = option_menu(
            "Main Menu",
            ["Train", "Test","View", "Draw"],
            icons=['house', 'file-slides','app-indicator','person lines fill'],
            menu_icon="list",
            default_index=0,
            styles=menu_styles,
        )

    # Page title and description.
    st.title("MMDETECTION VIEW")
    st.header("病例细胞实例分割可视化页面")

    # Map each menu entry to the function that renders its page.
    pages = {
        "Train": creat_page_train,
        "Test": creat_page_test,
        "View": creat_page_view,
        "Draw": creat_page_draw,
    }
    render_page = pages.get(choose)
    if render_page is not None:
        render_page()


if __name__ == '__main__':
    main()