Skip to content

Commit 75fbf09

Browse files
Feature Add queue (#25)
* Update README.md * Add queueing * Move default properties to properties.py * Reorganize * Make worker cnt configurable * Update README.md
1 parent 3a7826e commit 75fbf09

File tree

8 files changed

+156
-84
lines changed

8 files changed

+156
-84
lines changed

README.md

+5-3
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ Microservice that generates subtitles for [TUM-Live](https://live.rbg.tum.de).
3030
```bash
3131
$ grpcurl -plaintext localhost:50055 list live.voice.v1.SubtitleGenerator
3232

33-
voice.SubtitleGenerator.Generate
33+
live.voice.v1.SubtitleGenerator.Generate
3434
```
3535

3636
```bash
@@ -88,7 +88,8 @@ VOSK_MODEL_DIR=/data
8888
VOSK_DWNLD_URLS=https://alphacephei.com/vosk/models/vosk-model-small-en-us-0.15.zip,https://alphacephei.com/vosk/models/vosk-model-small-de-0.15.zip
8989
VOSK_MODELS=model-fr:fr,model-en:en
9090
WHISPER_MODEL=medium
91-
MAX_WORKERS=10
91+
MAX_THREADS=10
92+
CNT_WORKERS=3
9293
```
9394
</p>
9495
</details>
@@ -115,7 +116,8 @@ vosk:
115116
lang: 'de'
116117
whisper:
117118
model: 'tiny'
118-
max_workers: 10
119+
max_threads: 12
120+
cnt_workers: 3
119121
```
120122
</p>
121123
</details>

config.yml

+2-1
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,5 @@ vosk:
1616
lang: 'de'
1717
whisper:
1818
model: 'tiny'
19-
max_workers: 12
19+
max_threads: 12
20+
cnt_workers: 3

subtitles/client.py

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
"""Implements gRPC Client facade"""
2+
3+
import logging
4+
import grpc
5+
import subtitles_pb2_grpc, subtitles_pb2
6+
from grpc._channel import _InactiveRpcError
7+
8+
9+
def receive(receiver: str, req: subtitles_pb2.ReceiveRequest):
10+
with grpc.insecure_channel(receiver) as channel:
11+
stub = subtitles_pb2_grpc.SubtitleReceiverStub(channel)
12+
try:
13+
stub.Receive(req)
14+
except _InactiveRpcError as grpc_err:
15+
logging.error(grpc_err.details())
16+
except Exception as err:
17+
logging.error(err)

subtitles/properties.py

+21-3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,20 @@
33
import os.path
44
import yaml
55

6+
DEFAULT_PROPERTIES = {
7+
'api': {'port': 50055},
8+
'receiver': {'host': 'localhost', 'port': '50053'},
9+
'transcriber': 'whisper',
10+
'vosk': {
11+
'model_dir': '/tmp',
12+
'download_urls': [],
13+
'models': []
14+
},
15+
'whisper': {'model': 'tiny'},
16+
'max_threads': None,
17+
'cnt_workers': 1,
18+
}
19+
620

721
class PropertyError(Exception):
822
pass
@@ -77,9 +91,13 @@ def get(self) -> dict:
7791

7892
properties['whisper']['model'] = os.getenv('WHISPER_MODEL', properties['whisper']['model'])
7993

80-
max_workers = os.getenv('MAX_WORKERS', properties['max_workers'])
81-
if max_workers:
82-
properties['max_workers'] = int(max_workers)
94+
max_threads = os.getenv('MAX_THREADS', properties['max_threads'])
95+
if max_threads:
96+
properties['max_threads'] = int(max_threads)
97+
98+
cnt_workers = os.getenv('CNT_WORKERS', properties['cnt_workers'])
99+
if cnt_workers:
100+
properties['cnt_workers'] = int(cnt_workers)
83101

84102
return properties
85103

subtitles/subtitles.py

+44-77
Original file line numberDiff line numberDiff line change
@@ -3,37 +3,31 @@
33
import os
44
from signal import signal, SIGTERM, SIGINT, SIGQUIT, strsignal
55
from concurrent.futures import ThreadPoolExecutor
6-
from properties import YAMLPropertiesFile, EnvProperties, PropertyError
6+
from properties import YAMLPropertiesFile, EnvProperties, PropertyError, DEFAULT_PROPERTIES
77
from grpc_reflection.v1alpha import reflection
8-
from grpc._channel import _InactiveRpcError
98
from google.protobuf import empty_pb2
109
from model_loader import download_models, ModelLoadError
1110
import grpc
1211
import subtitles_pb2
1312
import subtitles_pb2_grpc
13+
from taskqueue import TaskQueue
1414
from vosk_transcriber import VoskTranscriber
1515
from whisper_transcriber import WhisperTranscriber
1616
from transcriber import Transcriber
17+
from tasks import GenerationTask
18+
from worker import Worker
1719

1820

1921
class SubtitleServerService(subtitles_pb2_grpc.SubtitleGeneratorServicer):
2022
"""grpc service for subtitles"""
2123

22-
def __init__(self, transcriber: Transcriber, receiver: str, executor: ThreadPoolExecutor) -> None:
23-
"""Initialize service.
24-
25-
Args:
26-
transcriber: The transcriber used for subtitle generation.
27-
receiver: The address of the receiver service.
28-
executor: Threadpool for jobs.
29-
"""
30-
self.__transcriber = transcriber
31-
self.__receiver = receiver
32-
self.__executor = executor
24+
def __init__(self, queue: TaskQueue) -> None:
25+
"""Initialize service"""
26+
self.__queue = queue
3327

3428
def Generate(self, req: subtitles_pb2.GenerateRequest,
3529
context: grpc.ServicerContext) -> empty_pb2.Empty:
36-
""" Handler function for an incoming Generate request.
30+
"""Handler function for an incoming Generate request.
3731
3832
Args:
3933
req: An object holding the grpc message data.
@@ -50,71 +44,50 @@ def Generate(self, req: subtitles_pb2.GenerateRequest,
5044
logging.debug(f'checking if {source} exists')
5145
if not os.path.isfile(source):
5246
context.abort(grpc.StatusCode.NOT_FOUND, f'can not find source file: {source}')
53-
return empty_pb2.Empty()
54-
55-
logging.debug('starting thread to generate subtitles')
56-
self.__executor.submit(self.__generate, self.__transcriber, source, stream_id, language)
57-
return empty_pb2.Empty()
47+
return
5848

59-
def __generate(self, transcriber: Transcriber, source: str, stream_id: str, language: str) -> None:
60-
subtitles, language = transcriber.generate(source, language)
49+
logging.debug('enqueue request')
50+
self.__queue.put(GenerationTask(source, language, stream_id))
6151

62-
logging.info(f'trying to connect to receiver @ {self.__receiver}')
63-
with grpc.insecure_channel(self.__receiver) as channel:
64-
stub = subtitles_pb2_grpc.SubtitleReceiverStub(channel)
65-
request = subtitles_pb2.ReceiveRequest(
66-
stream_id=stream_id,
67-
subtitles=subtitles,
68-
language=language)
69-
70-
try:
71-
stub.Receive(request)
72-
logging.info('subtitle-request sent')
73-
except _InactiveRpcError as grpc_err:
74-
logging.error(grpc_err.details())
75-
except Exception as err:
76-
logging.error(err)
52+
return empty_pb2.Empty()
7753

7854

79-
def serve(transcriber: Transcriber,
80-
receiver: str,
55+
def serve(executor: ThreadPoolExecutor,
56+
q: TaskQueue,
8157
port: int,
82-
max_workers: int,
8358
debug: bool = False) -> None:
8459
"""Starts the grpc server.
8560
8661
Args:
87-
transcriber: The transcriber used.
88-
receiver: The network address of the receiver.
62+
executor: The pool of threads
63+
q: Queue of tasks
8964
port: The port on which the voice service listens.
90-
max_workers: The maximum number of threads that can be used to execute the given calls.
9165
debug: Whether the server should be started in debug mode or not.
9266
"""
67+
server = grpc.server(executor)
68+
subtitles_pb2_grpc.add_SubtitleGeneratorServicer_to_server(
69+
servicer=SubtitleServerService(q),
70+
server=server)
9371

94-
with ThreadPoolExecutor(max_workers) as executor:
95-
server = grpc.server(executor)
96-
subtitles_pb2_grpc.add_SubtitleGeneratorServicer_to_server(
97-
servicer=SubtitleServerService(transcriber, receiver, executor),
98-
server=server)
72+
if debug:
73+
activate_reflection(server)
9974

100-
if debug:
101-
activate_reflection(server)
75+
logging.info(f'listening at :{port}')
76+
server.add_insecure_port(f'[::]:{port}')
77+
server.start()
10278

103-
logging.info(f'listening at :{port}')
104-
server.add_insecure_port(f'[::]:{port}')
105-
server.start()
79+
def handle_shutdown(signum, *_):
80+
logging.info(f'received "{strsignal(signum)}" signal')
81+
all_requests_done = server.stop(16)
82+
all_requests_done.wait(16)
83+
q.stop()
84+
executor.shutdown(wait=True)
85+
logging.info('shut down gracefully')
10686

107-
def handle_shutdown(signum, *_):
108-
logging.info(f'received "{strsignal(signum)}" signal')
109-
all_requests_done = server.stop(30)
110-
all_requests_done.wait(30)
111-
executor.shutdown(wait=True)
112-
logging.info('shut down gracefully')
113-
114-
signal(SIGTERM, handle_shutdown)
115-
signal(SIGINT, handle_shutdown)
116-
signal(SIGQUIT, handle_shutdown)
117-
server.wait_for_termination()
87+
signal(SIGTERM, handle_shutdown)
88+
signal(SIGINT, handle_shutdown)
89+
signal(SIGQUIT, handle_shutdown)
90+
server.wait_for_termination()
11891

11992

12093
def get_transcriber(properties: dict, debug: bool) -> Transcriber:
@@ -141,18 +114,7 @@ def main():
141114
debug = os.getenv('DEBUG', '') != ""
142115
logging.basicConfig(level=(logging.INFO, logging.DEBUG)[debug])
143116

144-
properties = {
145-
'api': {'port': 50055},
146-
'receiver': {'host': 'localhost', 'port': '50053'},
147-
'transcriber': 'whisper',
148-
'vosk': {
149-
'model_dir': '/tmp',
150-
'download_urls': [],
151-
'models': []
152-
},
153-
'whisper': {'model': 'tiny'},
154-
'max_workers': None,
155-
}
117+
properties = DEFAULT_PROPERTIES
156118

157119
try:
158120
config_file = os.getenv("CONFIG_FILE")
@@ -177,10 +139,15 @@ def main():
177139
transcriber = get_transcriber(properties, debug)
178140
receiver = f'{properties["receiver"]["host"]}:{properties["receiver"]["port"]}'
179141
port = properties['api']['port']
180-
max_workers = properties['max_workers']
142+
max_threads = properties['max_threads']
143+
cnt_workers = properties['cnt_workers']
181144

182145
logging.debug(properties)
183-
serve(transcriber, receiver, port, max_workers, debug)
146+
147+
q = TaskQueue(cnt_workers)
148+
with ThreadPoolExecutor(max_threads) as executor:
149+
[Worker(transcriber, receiver, executor, q) for _ in range(cnt_workers)]
150+
serve(executor, q, port, debug)
184151

185152

186153
if __name__ == "__main__":

subtitles/taskqueue.py

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import queue
2+
3+
from tasks import StopTask
4+
5+
6+
class TaskQueue(queue.Queue):
7+
def __init__(self, task_worker_cnt: int = 1):
8+
super().__init__()
9+
self.task_worker_cnt = task_worker_cnt
10+
11+
def stop(self):
12+
for _ in range(self.task_worker_cnt):
13+
self.put(StopTask())

subtitles/tasks.py

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from dataclasses import dataclass
2+
3+
4+
@dataclass
5+
class GenerationTask:
6+
source: str
7+
language: str
8+
stream_id: str
9+
10+
11+
class StopTask:
12+
pass

subtitles/worker.py

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import logging
2+
from concurrent.futures import ThreadPoolExecutor
3+
import subtitles_pb2
4+
from tasks import GenerationTask, StopTask
5+
from taskqueue import TaskQueue
6+
from transcriber import Transcriber
7+
from client import receive
8+
9+
10+
class Worker:
11+
"""Thread for subtitle generation"""
12+
13+
def __init__(self,
14+
transcriber: Transcriber,
15+
receiver: str,
16+
executor: ThreadPoolExecutor,
17+
taskqueue: TaskQueue):
18+
"""Start the generator threads."""
19+
executor.submit(run, transcriber, receiver, taskqueue)
20+
21+
22+
def run(transcriber: Transcriber, receiver: str, taskqueue: TaskQueue) -> None:
23+
while True:
24+
logging.info('worker: waiting for task...')
25+
task = taskqueue.get()
26+
if isinstance(task, GenerationTask):
27+
logging.info('worker: starting to generate subtitles...')
28+
logging.debug(f'worker: task: {task}')
29+
generate(transcriber, receiver, task)
30+
elif isinstance(task, StopTask):
31+
break
32+
33+
34+
def generate(transcriber: Transcriber, receiver: str, task: GenerationTask) -> None:
35+
subtitles, language = transcriber.generate(task.source, task.language)
36+
37+
logging.info(f'worker: sending receive message to receiver @ {receiver}')
38+
receive(receiver,
39+
req=subtitles_pb2.ReceiveRequest(
40+
stream_id=task.stream_id,
41+
subtitles=subtitles,
42+
language=language))

0 commit comments

Comments
 (0)