forked from THUDM/VisualGLM-6B
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path web_demo.py
130 lines (104 loc) · 5.48 KB
/
web_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#!/usr/bin/env python
import gradio as gr
from PIL import Image
import os
import json
from model import is_chinese, get_infer_setting, generate_input, chat
import torch
def generate_text_with_image(input_text, image, history=None, request_data=None, is_zh=True):
    """Run one chat turn about ``image`` and return the model's answer.

    Uses the module-level ``model`` and ``tokenizer`` globals, which ``main()``
    assigns before the UI starts serving requests.

    Args:
        input_text: The user's prompt for this turn.
        image: The image to condition on; handed to ``generate_input`` with
            ``image_is_encoded=False`` (the web UI passes the result of
            ``Image.open``).
        history: Prior (query, answer) pairs. ``None`` means an empty history.
        request_data: Optional overrides merged on top of the default
            generation parameters below (e.g. ``temperature``, ``top_p``).
        is_zh: Whether the prompt is Chinese; ``chat`` is called with
            ``english=not is_zh``.

    Returns:
        The answer returned by ``chat``.
    """
    # None sentinels instead of mutable defaults: a default list/dict would be
    # one shared object reused by every call that omits the argument.
    if history is None:
        history = []
    input_para = {
        "max_length": 2048,
        "min_length": 50,
        "temperature": 0.8,
        "top_p": 0.4,
        "top_k": 100,
        "repetition_penalty": 1.2,
    }
    if request_data:
        input_para.update(request_data)
    input_data = generate_input(input_text, image, history, input_para, image_is_encoded=False)
    input_image, gen_kwargs = input_data['input_image'], input_data['gen_kwargs']
    # Only max_length/top_p/top_k/temperature are forwarded to chat();
    # min_length and repetition_penalty appear in input_para but are not
    # passed here explicitly.
    with torch.no_grad():
        # NOTE(review): the leading None positional argument is presumably an
        # unused image path (the image is supplied via image=) — confirm
        # against model.chat.
        answer, _updated_history, _ = chat(
            None, model, tokenizer, input_text,
            history=history,
            image=input_image,
            max_length=gen_kwargs['max_length'],
            top_p=gen_kwargs['top_p'],
            top_k=gen_kwargs['top_k'],
            temperature=gen_kwargs['temperature'],
            english=not is_zh,
        )
    return answer
def request_model(input_text, temperature, top_p, image_prompt, result_previous):
    """Gradio callback: answer ``input_text`` about the uploaded image.

    Args:
        input_text: Text from the prompt textbox.
        temperature: Sampling temperature from the slider.
        top_p: Top-p value from the slider.
        image_prompt: The uploaded image as passed to ``Image.open`` (the UI's
            Image component is configured with ``type="filepath"``), or None
            when no image is present.
        result_previous: Current chatbot history as (query, answer) pairs.

    Returns:
        ``(textbox_value, chat_history)``. The textbox keeps its text when the
        image is missing; in every other case it is cleared.
    """
    # Drop turns with an empty query or answer, such as the ("", greeting)
    # placeholder the chatbot starts with.
    result_text = [(ele[0], ele[1]) for ele in result_previous
                   if ele[0] != "" and ele[1] != ""]
    print(f"history {result_text}")
    is_zh = is_chinese(input_text)
    if image_prompt is None:
        if is_zh:
            result_text.append((input_text, '图片为空!请上传图片并重试。'))
        else:
            result_text.append((input_text, 'Image empty! Please upload a image and retry.'))
        return input_text, result_text
    if input_text == "":
        result_text.append((input_text, 'Text empty! Please enter text and retry.'))
        return "", result_text
    request_para = {"temperature": temperature, "top_p": top_p}
    # Context manager releases the image's file handle as soon as generation
    # finishes, instead of leaving it open until garbage collection.
    with Image.open(image_prompt) as image:
        try:
            answer = generate_text_with_image(input_text, image, result_text.copy(), request_para, is_zh)
        except Exception as e:
            import traceback
            print(f"error: {e}")
            traceback.print_exc()
            # Any failure (not only a timeout) is reported with this message.
            if is_zh:
                result_text.append((input_text, '超时!请稍等几分钟再重试。'))
            else:
                result_text.append((input_text, 'Timeout! Please wait a few minutes and retry.'))
            return "", result_text
    result_text.append((input_text, answer))
    print(result_text)
    return "", result_text
# Markdown heading shown at the top of the page
# ("large-model automatic image annotation").
DESCRIPTION = '# 大模型图片自动标注'
# Two usage tips, separated by a newline:
#   1. if the app reports "Something went wrong, connection error out",
#      turn off any proxy and retry;
#   2. very large uploads (e.g. 10MB) take time to upload and process.
MAINTENANCE_NOTICE1 = (
    '提示1: 如果应用报了“Something went wrong, connection error out”的错误,请关闭代理并重试。\n'
    '提示2: 如果你上传了很大的图片,比如10MB大小,那将需要一些时间来上传和处理,请耐心等待。'
)
# Same text as MAINTENANCE_NOTICE1; this file's UI only renders NOTICE1.
MAINTENANCE_NOTICE2 = MAINTENANCE_NOTICE1
# Footer text rendered below the main layout.
NOTES = '大模型图片自动标注'
def _greeting_history():
    """Build a new chat history containing only the opening greeting turn."""
    return [("", "你想了解什么?")]


def clear_fn(value):
    """Reset the textbox, chat history, and image; ``value`` is ignored.

    Returns:
        ``(textbox_text, chat_history, image)`` = empty text, a history holding
        only the greeting, and no image.
    """
    cleared_text = ""
    cleared_image = None
    return cleared_text, _greeting_history(), cleared_image


def clear_fn2(value):
    """Reset only the chat history to the greeting; ``value`` is ignored."""
    return _greeting_history()
def main(args):
    """Load the model and serve the Gradio image-annotation demo.

    Args:
        args: Parsed command-line namespace; ``args.quant`` and ``args.model``
            are passed to ``get_infer_setting`` and ``args.share`` to
            ``demo.launch``.
    """
    gr.close_all()
    # generate_text_with_image() reads these as module-level globals.
    global model, tokenizer
    model, tokenizer = get_infer_setting(gpu_device=0, quant=args.quant, model_path=args.model)

    # Resolve the stylesheet next to this script rather than the current
    # working directory: Gradio treats a non-existent path as literal CSS
    # text, so a relative path silently drops the styling when the demo is
    # launched from another directory.
    css_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'style.css')

    with gr.Blocks(css=css_path) as demo:
        gr.Markdown(DESCRIPTION)
        with gr.Row():
            with gr.Column(scale=4.5):
                with gr.Group():
                    input_text = gr.Textbox(label='Input Text', placeholder='Please enter text prompt below and press ENTER.')
                    with gr.Row():
                        run_button = gr.Button('Generate')
                        clear_button = gr.Button('Clear')
                    image_prompt = gr.Image(type="filepath", label="Image Prompt", value=None)
                with gr.Row():
                    temperature = gr.Slider(maximum=1, value=0.8, minimum=0, label='Temperature')
                    top_p = gr.Slider(maximum=1, value=0.4, minimum=0, label='Top P')
                with gr.Group():
                    with gr.Row():
                        gr.Markdown(MAINTENANCE_NOTICE1)
            with gr.Column(scale=5.5):
                result_text = gr.components.Chatbot(label='Multi-round conversation History', value=[("", "Hi, What do you want to know about this image?")]).style(height=550)

        gr.Markdown(NOTES)

        print(gr.__version__)
        # The Generate button and pressing ENTER in the textbox share one handler.
        request_inputs = [input_text, temperature, top_p, image_prompt, result_text]
        request_outputs = [input_text, result_text]
        run_button.click(fn=request_model, inputs=request_inputs, outputs=request_outputs)
        input_text.submit(fn=request_model, inputs=request_inputs, outputs=request_outputs)
        # clear_button is wired as a dummy input; clear_fn/clear_fn2 ignore it.
        clear_button.click(fn=clear_fn, inputs=clear_button, outputs=[input_text, result_text, image_prompt])
        # Uploading or clearing the image resets the conversation history.
        image_prompt.upload(fn=clear_fn2, inputs=clear_button, outputs=[result_text])
        image_prompt.clear(fn=clear_fn2, inputs=clear_button, outputs=[result_text])

    demo.queue(concurrency_count=10)
    demo.launch(share=args.share)
if __name__ == '__main__':
    import argparse

    # --quant: 8 or 4, forwarded to get_infer_setting(quant=...);
    # --share: forwarded to demo.launch(share=...);
    # --model: forwarded to get_infer_setting(model_path=...).
    cli = argparse.ArgumentParser()
    cli.add_argument("--quant", type=int, choices=[8, 4], default=None)
    cli.add_argument("--share", action="store_true")
    cli.add_argument("--model", default='visualglm-6b')
    args = cli.parse_args()
    main(args)