stream.py
#!/usr/bin/env python3
import threading


class StreamingResponse:
"""
Asynchronous output token iterator returned from a model's generate() function.
Use it to stream the reply from the LLM as they are decoded token-by-token:
```
response = model.generate("Once upon a time,")
for token in response:
print(token, end='', flush=True)
```
To terminate processing prematurely, call the .stop() function, which will stop the model
from generating additional output tokens. Otherwise tokens will continue to be filled.
"""

    def __init__(self, model, input, **kwargs):
        super().__init__()

        self.model = model
        self.input = input
        self.event = threading.Event()  # signaled by the model whenever new output tokens are decoded
        self.kwargs = kwargs
        self.kv_cache = kwargs.get('kv_cache', None)  # optional KV cache to reuse from a previous generation

        self.stopping = False    # set if the user requested early termination
        self.stopped = False     # set when generation has actually stopped
        self.output_tokens = []  # accumulated output tokens so far
        self.output_text = ''    # detokenized output text so far

    def __iter__(self):
        return self

    def __next__(self):
        if self.stopped:
            if len(self.output_tokens) == 0 or self.output_tokens[-1] != self.model.tokenizer.eos_token_id:
                self.output_tokens.append(self.model.tokenizer.eos_token_id)  # add EOS if necessary
                return self.get_message_delta()
            raise StopIteration

        self.event.wait()   # block until the model signals that new tokens were decoded
        self.event.clear()

        return self.get_message_delta()

    def stop(self):
        self.stopping = True  # request early termination; the model checks this flag while decoding

    def get_message_delta(self):
        # Decode the full token sequence each time (rather than token-by-token) so that
        # characters spanning multiple tokens detokenize correctly, then return only
        # the text appended since the last call.
        message = self.model.tokenizer.decode(self.output_tokens, skip_special_tokens=False)
        delta = message[len(self.output_text):]
        self.output_text = message
        return delta
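

# --------------------------------------------------------------------------
# Minimal usage sketch: one way a model's generate() could drive
# StreamingResponse. A background thread appends token IDs to
# response.output_tokens, wakes the consumer with response.event.set(),
# honors response.stopping, and sets response.stopped when finished.
# _FakeTokenizer and _FakeModel are toy stand-ins for a real model and
# tokenizer, invented for this demo.
# --------------------------------------------------------------------------
if __name__ == '__main__':
    import time

    class _FakeTokenizer:
        eos_token_id = -1
        vocab = {0: 'Once', 1: ' upon', 2: ' a', 3: ' time...', -1: ''}

        def decode(self, tokens, skip_special_tokens=False):
            return ''.join(self.vocab[t] for t in tokens)

    class _FakeModel:
        tokenizer = _FakeTokenizer()

        def generate(self, prompt, **kwargs):
            response = StreamingResponse(self, prompt, **kwargs)

            def decode_loop():
                for token_id in [0, 1, 2, 3]:
                    if response.stopping:      # user called response.stop()
                        break
                    time.sleep(0.1)            # simulate per-token decode latency
                    response.output_tokens.append(token_id)
                    response.event.set()       # wake the iterating consumer
                response.stopped = True
                response.event.set()           # final wake-up so iteration can end

            threading.Thread(target=decode_loop, daemon=True).start()
            return response

    response = _FakeModel().generate("Once upon a time,")

    for token in response:
        print(token, end='', flush=True)
    print()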