-
Notifications
You must be signed in to change notification settings - Fork 6
/
utils.py
72 lines (63 loc) · 1.99 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from collections import deque
import json
from threading import Thread
from model import Predictor
import time
class SLInference:
    """
    Main prediction thread.

    Input items are pushed into a bounded queue by the caller; a background
    worker thread runs the model whenever the queue holds a full window and
    publishes the first predicted label in ``pred``.

    Attributes:
        running (bool): Flag to control the running of the thread.
        config (dict): Configuration parameters for the model.
        model (Predictor): The prediction model.
        input_queue (deque): A queue to hold the input data. Its ``maxlen`` is
            ``config["window_size"]``, so the oldest items are dropped once full.
        pred (str): The prediction result ("" when the last prediction was empty).
        thread (Thread | None): The worker thread, or ``None`` before ``start()``.
    """

    def __init__(self, config_path):
        """
        Initialize the SLInference object.

        Args:
            config_path (str): Path to the JSON configuration file. It must
                contain at least a ``"window_size"`` key.
        """
        self.running = True
        self.config = self.read_config(config_path)
        self.model = Predictor(self.config)
        self.input_queue = deque(maxlen=self.config["window_size"])
        self.pred = ""
        # Defined up front so stop() is safe even if start() was never called.
        self.thread = None

    def read_config(self, config_path):
        """
        Read the configuration file.

        Args:
            config_path (str): Path to the JSON configuration file.

        Returns:
            dict: The configuration parameters.
        """
        with open(config_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def worker(self):
        """
        The main worker function that runs in a separate thread.

        Polls the input queue until ``running`` becomes False. Whenever the queue
        holds exactly ``config["window_size"]`` items, the model is run on a
        snapshot of the window. A truthy result stores its first label in
        ``pred`` and empties the queue; a falsy result resets ``pred`` to "".
        """
        while self.running:
            if len(self.input_queue) == self.config["window_size"]:
                # Hand the model a snapshot rather than the live deque: other
                # threads keep appending to input_queue, and iterating a deque
                # that is being mutated raises RuntimeError. deque.copy() keeps
                # the same maxlen and, in CPython, runs without yielding the GIL.
                frames = self.input_queue.copy()
                pred_dict = self.model.predict(frames)
                if pred_dict:
                    self.pred = pred_dict["labels"][0]
                    # Note: this also discards items appended while predicting.
                    self.input_queue.clear()
                else:
                    self.pred = ""
            # Sleep on every iteration so an under-filled queue does not turn
            # this loop into a CPU-burning busy-wait.
            time.sleep(0.1)

    def start(self):
        """
        Start the worker thread.

        Does nothing if a worker is already running. After ``stop()``, calling
        ``start()`` again launches a fresh worker.
        """
        if self.thread is not None and self.thread.is_alive():
            return
        # Reset the flag so a restart after stop() does not exit immediately.
        self.running = True
        self.thread = Thread(target=self.worker)
        self.thread.start()

    def stop(self):
        """
        Stop the worker thread and wait for it to finish.

        Safe to call before ``start()`` and safe to call more than once.
        """
        self.running = False
        if self.thread is not None:
            self.thread.join()