__main__.py
#!/usr/bin/env python3
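#
# Interactive command-line chat example built on the local_llm API.
# A typical invocation might look like the following (the module path and the
# model name here are illustrative assumptions, not taken from this file):
#
#   python3 -m local_llm.chat --api=mlc --model=meta-llama/Llama-2-7b-chat-hf
#
# Press Ctrl+C once to interrupt the bot mid-reply, and twice in quick
# succession to exit the program.
#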
import os
import sys
import time
import signal
import logging
import numpy as np
from termcolor import cprint
from local_llm import LocalLM, ChatHistory, ChatTemplates
from local_llm.utils import ImageExtensions, ArgParser, load_prompts, print_table
# see utils/args.py for options
parser = ArgParser()
parser.add_argument("--no-streaming", action="store_true", help="wait to output entire reply instead of token by token")
args = parser.parse_args()
prompts = load_prompts(args.prompt)
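
# load_prompts() presumably returns a list of prompt strings when --prompt was
# given (possibly loaded from a .txt/.json file), or None for interactive mode
# (the isinstance(prompts, list) check in the main loop depends on this)
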
# load model
model = LocalLM.from_pretrained(
    args.model,
    quant=args.quant,
    api=args.api,
    vision_model=args.vision_model
)
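
# the vision_model argument implies multimodal support - image prompts are
# handled separately in the chat loop below
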
# create the chat history
chat_history = ChatHistory(model, args.chat_template, args.system_prompt)
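# ChatHistory presumably applies the chat template to each turn and tracks
# the embeddings and kv_cache that carry the conversation across replies
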
# make an interrupt handler for muting the bot output
last_interrupt = 0.0
interrupt_chat = False
def on_interrupt(signum, frame):
    """
    Ctrl+C handler - if pressed once, interrupts the LLM.
    If pressed twice in quick succession, exits the program.
    """
    global last_interrupt
    global interrupt_chat

    curr_time = time.perf_counter()
    time_diff = curr_time - last_interrupt
    last_interrupt = curr_time

    if time_diff > 2.0:
        logging.warning("Ctrl+C: interrupting chatbot")
        interrupt_chat = True
    else:
        # second Ctrl+C within 2 seconds - exit immediately
        logging.warning("Ctrl+C: exiting...")
        sys.exit(0)
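
# install the handler for SIGINT (delivered when Ctrl+C is pressed)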
signal.signal(signal.SIGINT, on_interrupt)

while True:
    # get the next prompt from the list, or from the user interactively
    if isinstance(prompts, list):
        if len(prompts) > 0:
            user_prompt = prompts.pop(0)
            cprint(f'>> PROMPT: {user_prompt}', 'blue')
        else:
            break
    else:
        cprint('>> PROMPT: ', 'blue', end='', flush=True)
        user_prompt = sys.stdin.readline().strip()

    print('')

    # special commands: load prompts from file
    # 'reset' or 'clear' resets the chat history
    if user_prompt.lower().endswith(('.txt', '.json')):
        user_prompt = ' '.join(load_prompts(user_prompt))
    elif user_prompt.lower() == 'reset' or user_prompt.lower() == 'clear':
        logging.info("resetting chat history")
        chat_history.reset()
        continue

    # add the latest user prompt to the chat history
    entry = chat_history.append(role='user', msg=user_prompt)

    # images should be followed by text prompts
    if 'image' in entry and 'text' not in entry:
        logging.debug("image message, waiting for user prompt")
        continue

    # get the latest embeddings from the chat
    embedding, position = chat_history.embed_chat()

    if logging.getLogger().isEnabledFor(logging.DEBUG):
        logging.debug(f"adding embedding shape={embedding.shape} position={position}")

    # generate bot reply
    reply = model.generate(
        embedding,
        streaming=not args.no_streaming,
        kv_cache=chat_history.kv_cache,
        max_new_tokens=args.max_new_tokens,
        min_new_tokens=args.min_new_tokens,
        do_sample=args.do_sample,
        repetition_penalty=args.repetition_penalty,
        temperature=args.temperature,
        top_p=args.top_p,
    )
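
    # when streaming, reply is presumably a generator-like object yielding
    # tokens and exposing stop(), kv_cache, and output_text (all used below);
    # with --no-streaming it holds the completed reply text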

    bot_reply = chat_history.append(role='bot', text='')  # placeholder

    if args.no_streaming:
        bot_reply.text = reply
        cprint(reply, 'green')
    else:
        for token in reply:
            bot_reply.text += token
            cprint(token, 'green', end='', flush=True)
            if interrupt_chat:
                reply.stop()
                interrupt_chat = False
                break

    print('\n')
    print_table(model.stats)
    print('')

    chat_history.kv_cache = reply.kv_cache  # save the kv_cache for the next turn
    bot_reply.text = reply.output_text      # sync the text once more, in case the reply was interrupted
#logging.warning('exiting...')