forked from THUDM/VisualGLM-6B
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcli_demo.py
118 lines (103 loc) · 4.85 KB
/
cli_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# -*- encoding: utf-8 -*-
import os
import sys
import torch
import argparse
from transformers import AutoTokenizer
from sat.model.mixins import CachedAutoregressiveMixin
from sat.quantization.kernels import quantize
from model import VisualGLMModel, chat
from finetune_visualglm import FineTuneVisualGLMModel
from sat.model import AutoModel
def _ask(english, prompt_zh, prompt_en):
    """Prompt on stdin in the active language and return the typed line.

    Args:
        english: When true, show ``prompt_en``; otherwise show ``prompt_zh``.
        prompt_zh: Prompt text used in Chinese mode.
        prompt_en: Prompt text used in English mode.

    Returns:
        The line entered by the user, or the command ``"stop"`` when stdin
        reaches end-of-input (Ctrl-D, closed pipe), so that callers shut
        down through their normal ``stop`` path instead of crashing with an
        ``EOFError`` traceback.
    """
    try:
        return input(prompt_en if english else prompt_zh)
    except EOFError:
        print()  # terminate the dangling prompt line before exiting
        return "stop"


def _ask_query(english):
    """Read the next user turn of the conversation (or ``"stop"`` on EOF)."""
    return _ask(english, "用户:", "User: ")


def main():
    """Run the interactive VisualGLM-6B command-line chat demo.

    Parses the sampling / quantization options, loads the checkpoint named
    by ``--from_pretrained`` and then loops forever:

    1. ask for an image path or URL (an empty line starts a text-only chat);
    2. chat turn by turn until the user types ``clear`` (drop the history
       and go back to step 1) or ``stop`` (quit the program).

    End-of-input on stdin is treated like ``stop``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_length", type=int, default=2048, help='max length of the total sequence')
    parser.add_argument("--top_p", type=float, default=0.4, help='top p for nucleus sampling')
    parser.add_argument("--top_k", type=int, default=100, help='top k for top k sampling')
    parser.add_argument("--temperature", type=float, default=.8, help='temperature for sampling')
    parser.add_argument("--english", action='store_true', help='only output English')
    parser.add_argument("--quant", choices=[8, 4], type=int, default=None, help='quantization bits')
    parser.add_argument("--from_pretrained", type=str, default="visualglm-6b", help='pretrained ckpt')
    parser.add_argument("--prompt_zh", type=str, default="描述这张图片。", help='Chinese prompt for the first round')
    parser.add_argument("--prompt_en", type=str, default="Describe the image.", help='English prompt for the first round')
    args = parser.parse_args()

    # Disabled multi-GPU experiment, kept for reference:
    # torch.distributed.init_process_group('nccl', init_method='env://')
    # rank = torch.distributed.get_rank()
    # print(f"rank = {rank} is initialized")
    # # On a single node with multiple GPUs, local rank == rank. Strictly
    # # speaking, local_rank should be used to select the device.
    # torch.cuda.set_device(rank)

    # load model
    # A model that is going to be quantized is loaded on the CPU first and
    # only moved to the GPU after quantization (see below).
    load_on_cuda = torch.cuda.is_available() and args.quant is None
    model, _model_args = AutoModel.from_pretrained(
        args.from_pretrained,
        args=argparse.Namespace(
            fp16=True,
            skip_init=True,
            # NOTE(review): hard-coded GPU count; presumably belongs to the
            # disabled multi-GPU experiment -- confirm it is still wanted.
            num_gpus=2,
            use_gpu_initialization=load_on_cuda,
            device='cuda' if load_on_cuda else 'cpu',
        ))

    # Disabled alternative: spread the model over GPUs with accelerate.
    # from accelerate import dispatch_model
    # from utils import auto_configure_device_map
    # if device_map is None:
    #     device_map = auto_configure_device_map(num_gpus=2)
    # model = dispatch_model(model, device_map=device_map)
    model = model.eval()

    if args.quant:
        quantize(model.transformer, args.quant)
        if torch.cuda.is_available():
            model = model.cuda()

    model.add_mixin('auto-regressive', CachedAutoregressiveMixin())

    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)

    if not args.english:
        print('欢迎使用 VisualGLM-6B 模型,输入图像URL或本地路径读图,继续输入内容对话,clear 重新开始,stop 终止程序')
    else:
        print('Welcome to VisualGLM-6B model. Enter an image URL or local file path to load an image. Continue inputting text to engage in a conversation. Type "clear" to start over, or "stop" to end the program.')

    # Language-dependent constants of the chat loop.
    first_round_prompt = args.prompt_en if args.english else args.prompt_zh
    sep = 'A:' if args.english else '答:'
    # The reply label follows --english like every other prompt; it used to
    # be printed in Chinese unconditionally.
    reply_label = "Model: " if args.english else "大模型:"

    with torch.no_grad():
        while True:
            # Each pass of the outer loop is a fresh session.
            history = None
            cache_image = None
            image_path = _ask(
                args.english,
                "请输入图像路径或URL(回车进入纯文本对话): ",
                "Please enter the image path or URL (press Enter for plain text conversation): ",
            )
            if image_path == 'stop':
                break

            if image_path:
                # An image was given: open with the canned description prompt.
                query = first_round_prompt
            else:
                query = _ask_query(args.english)

            while True:
                if query == "clear":
                    break
                if query == "stop":
                    sys.exit(0)
                try:
                    response, history, cache_image = chat(
                        image_path,
                        model,
                        tokenizer,
                        query,
                        history=history,
                        image=cache_image,
                        max_length=args.max_length,
                        top_p=args.top_p,
                        temperature=args.temperature,
                        top_k=args.top_k,
                        english=args.english,
                        # NOTE(review): presumably bans a token-id range so
                        # that only English is generated -- confirm in `chat`.
                        invalid_slices=[slice(63823, 130000)] if args.english else []
                    )
                except Exception as e:
                    # Deliberate best effort: report the failure and fall
                    # back to the image prompt instead of killing the demo.
                    print(e)
                    break

                # Only the text after the last answer separator is shown.
                print(reply_label + response.split(sep)[-1].strip())
                # Later turns pass the image returned by `chat` (cache_image)
                # rather than the path.
                image_path = None
                query = _ask_query(args.english)
# Script entry point: start the interactive demo only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()