plugin.py
#!/usr/bin/env python3
import logging
import queue
import threading


class Plugin(threading.Thread):
    """
    Base class for plugins that process incoming/outgoing data from connections
    with other plugins, forming a pipeline or graph. Plugins can run either
    single-threaded or in an independent thread that processes data out of a queue.

    Frequent categories of plugins:

      * sources:    text prompts, images/video
      * processors: LLM queries, RAG, dynamic LLM calls, image post-processors
      * outputs:    print to stdout, save images/video

    Parameters:

      output_channels (int) -- the number of sets of output connections the plugin has
      relay (bool) -- if true, will relay any inputs as outputs after processing
      drop_inputs (bool) -- if true, only the most recent input in the queue will be used
      threaded (bool) -- if true, will spawn an independent thread for processing the queue

    TODO: use queue.task_done() and queue.join() for external synchronization
    """

    def __init__(self, output_channels=1, relay=False, drop_inputs=False, threaded=True, **kwargs):
        """
        Initialize plugin
        """
        super().__init__(daemon=True)

        self.relay = relay
        self.drop_inputs = drop_inputs
        self.threaded = threaded
        self.interrupted = False
        self.processing = False

        self.outputs = [[] for i in range(output_channels)]
        self.output_channels = output_channels

        if threaded:
            self.input_queue = queue.Queue()
            self.input_event = threading.Event()

    def process(self, input, **kwargs):
        """
        Abstract process() function that plugin instances should implement.
        Don't call this function externally unless threaded=False, because
        otherwise the plugin's internal thread dispatches it from the queue.

        Plugins should return their output data (or None if there isn't any).
        You can also call self.output() directly instead of returning it.

        kwargs:

          sender (Plugin) -- only present if the data was sent from a previous plugin
        """
        raise NotImplementedError(f"plugin {type(self)} has not implemented process()")

    def add(self, plugin, channel=0, **kwargs):
        """
        Connect this plugin to another, so that this plugin's output is sent
        to the specified plugin instance.

        Parameters:

          plugin (Plugin|callable) -- either the plugin to link to, or a callback
                                      function that will be wrapped in a Callback plugin
          channel (int) -- the output channel on which to send data to the other plugin

        Returns a reference to this plugin instance (self)
        """
        from local_llm.plugins import Callback

        if not isinstance(plugin, Plugin):
            if not callable(plugin):
                raise TypeError(f"{type(self)}.add() expects either a Plugin instance or a callable function (was {type(plugin)})")
            plugin = Callback(plugin, **kwargs)

        self.outputs[channel].append(plugin)

        if isinstance(plugin, Callback):
            logging.debug(f"connected {type(self).__name__} to {plugin.function.__name__} on channel={channel}")  # TODO https://stackoverflow.com/a/25959545
        else:
            logging.debug(f"connected {type(self).__name__} to {type(plugin).__name__} on channel={channel}")

        return self
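
    # Note: add() returns self, so a chained call like a.add(b).add(c) attaches both
    # b and c as outputs of a (a fan-out on channel 0), not a linear chain.
    # A hypothetical three-stage pipeline would instead be wired as:
    #
    #   source.add(processor)    # source -> processor
    #   processor.add(printer)   # processor -> printer
    #
    # (source/processor/printer are illustrative names, not plugins defined here)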

    def find(self, type):
        """
        Return the first plugin of the specified type, searching this plugin and
        recursively through its output connections in the pipeline graph.
        """
        if isinstance(self, type):
            return self

        for output_channel in self.outputs:
            for output in output_channel:
                if isinstance(output, type):
                    return output
                plugin = output.find(type)
                if plugin is not None:
                    return plugin

        return None

    '''
    def __getitem__(self, type):
        """
        Subscript indexing [] operator alias for find()
        """
        return self.find(type)
    '''

    def __call__(self, input):
        """
        Callable () operator alias for the input() function
        """
        self.input(input)

    def input(self, input):
        """
        Add data to the plugin's processing queue (or if threaded=False, process it now)
        TODO: multiple input channels?
        """
        if self.threaded:
            #self.start()  # thread may not be started if plugin only called from a callback
            if self.drop_inputs:
                self.clear_inputs()
            self.input_queue.put(input)
            self.input_event.set()
        else:
            self.dispatch(input)
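
    # For example (illustrative sketch): with threaded=True the call below only
    # enqueues the data and returns immediately, while with threaded=False it
    # runs process() synchronously before returning:
    #
    #   plugin.input("some data")   # or equivalently: plugin("some data")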

    def output(self, output, channel=0):
        """
        Output data to the next plugin(s) on the specified channel (-1 for all channels)
        """
        if output is None:
            return

        if channel >= 0:
            for output_plugin in self.outputs[channel]:
                output_plugin.input(output)
        else:
            for output_channel in self.outputs:
                for output_plugin in output_channel:
                    output_plugin.input(output)

    def start(self):
        """
        Start threads for all plugins in the graph that have threading enabled.
        """
        if self.threaded:
            if not self.is_alive():
                super().start()

        for output_channel in self.outputs:
            for output in output_channel:
                output.start()

        return self

    def run(self):
        """
        @internal processes the queue forever when created with threaded=True
        """
        while True:
            self.input_event.wait()
            self.input_event.clear()

            while True:
                try:
                    self.dispatch(self.input_queue.get(block=False))
                except queue.Empty:
                    break

    def dispatch(self, input):
        """
        Invoke the process() function on incoming data
        """
        if self.interrupted:
            #logging.debug(f"{type(self)} resetting interrupted=False")
            self.interrupted = False

        self.processing = True
        outputs = self.process(input)
        self.processing = False

        self.output(outputs)

        if self.relay:
            self.output(input)

    def interrupt(self, clear_inputs=True, block=True):
        """
        Interrupt any ongoing/pending processing, and optionally clear the input queue.
        If block is true, this function will wait until any ongoing processing has finished.
        This is done so that any lingering outputs don't cascade downstream in the pipeline.
        """
        if clear_inputs:
            self.clear_inputs()

        self.interrupted = True

        while block and self.processing:
            continue  # TODO use an event for this?

    def clear_inputs(self):
        """
        Clear the input queue, dropping any data.
        """
        while True:
            try:
                self.input_queue.get(block=False)
            except queue.Empty:
                return
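

# --------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): UppercasePlugin and
# PrintPlugin below are hypothetical examples showing how a subclass implements
# process() and how plugins are chained into a pipeline with add(). It assumes
# the local_llm package is importable, since add() imports local_llm.plugins.Callback.
# Using threaded=False keeps the example synchronous, so no threads are started.
if __name__ == "__main__":

    class UppercasePlugin(Plugin):
        """Hypothetical plugin that upper-cases incoming text."""
        def process(self, input, **kwargs):
            return input.upper()   # returned data is forwarded via self.output()

    class PrintPlugin(Plugin):
        """Hypothetical sink plugin that prints whatever it receives."""
        def process(self, input, **kwargs):
            print(input)           # returning None means nothing is forwarded

    # wire up UppercasePlugin -> PrintPlugin on output channel 0
    upper = UppercasePlugin(threaded=False)
    upper.add(PrintPlugin(threaded=False))

    upper("hello, plugins")        # __call__ is an alias for input()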