Skip to content

Commit bde285d

Browse files
lkk12014402, root, pre-commit-ci[bot], and Spycsh
authored
move examples gateway (opea-project#992)
Co-authored-by: root <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sihan Chen <[email protected]>
1 parent f5c08d4 commit bde285d

File tree

17 files changed

+1236
-113
lines changed

17 files changed

+1236
-113
lines changed

AudioQnA/audioqna.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
import asyncio
55
import os
66

7-
from comps import AudioQnAGateway, MicroService, ServiceOrchestrator, ServiceType
7+
from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
8+
from comps.cores.proto.api_protocol import AudioChatCompletionRequest, ChatCompletionResponse
9+
from comps.cores.proto.docarray import LLMParams
10+
from fastapi import Request
811

9-
MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
1012
MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
1113
ASR_SERVICE_HOST_IP = os.getenv("ASR_SERVICE_HOST_IP", "0.0.0.0")
1214
ASR_SERVICE_PORT = int(os.getenv("ASR_SERVICE_PORT", 9099))
@@ -16,7 +18,7 @@
1618
TTS_SERVICE_PORT = int(os.getenv("TTS_SERVICE_PORT", 9088))
1719

1820

19-
class AudioQnAService:
21+
class AudioQnAService(Gateway):
2022
def __init__(self, host="0.0.0.0", port=8000):
2123
self.host = host
2224
self.port = port
@@ -50,9 +52,43 @@ def add_remote_service(self):
5052
self.megaservice.add(asr).add(llm).add(tts)
5153
self.megaservice.flow_to(asr, llm)
5254
self.megaservice.flow_to(llm, tts)
53-
self.gateway = AudioQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
55+
56+
async def handle_request(self, request: Request):
57+
data = await request.json()
58+
59+
chat_request = AudioChatCompletionRequest.parse_obj(data)
60+
parameters = LLMParams(
61+
# relatively lower max_tokens for audio conversation
62+
max_tokens=chat_request.max_tokens if chat_request.max_tokens else 128,
63+
top_k=chat_request.top_k if chat_request.top_k else 10,
64+
top_p=chat_request.top_p if chat_request.top_p else 0.95,
65+
temperature=chat_request.temperature if chat_request.temperature else 0.01,
66+
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
67+
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
68+
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
69+
streaming=False, # TODO add streaming LLM output as input to TTS
70+
)
71+
result_dict, runtime_graph = await self.megaservice.schedule(
72+
initial_inputs={"byte_str": chat_request.audio}, llm_parameters=parameters
73+
)
74+
75+
last_node = runtime_graph.all_leaves()[-1]
76+
response = result_dict[last_node]["byte_str"]
77+
78+
return response
79+
80+
def start(self):
81+
super().__init__(
82+
megaservice=self.megaservice,
83+
host=self.host,
84+
port=self.port,
85+
endpoint=str(MegaServiceEndpoint.AUDIO_QNA),
86+
input_datatype=AudioChatCompletionRequest,
87+
output_datatype=ChatCompletionResponse,
88+
)
5489

5590

5691
if __name__ == "__main__":
57-
audioqna = AudioQnAService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
92+
audioqna = AudioQnAService(port=MEGA_SERVICE_PORT)
5893
audioqna.add_remote_service()
94+
audioqna.start()

AudioQnA/audioqna_multilang.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
import base64
66
import os
77

8-
from comps import AudioQnAGateway, MicroService, ServiceOrchestrator, ServiceType
8+
from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
9+
from comps.cores.proto.api_protocol import AudioChatCompletionRequest, ChatCompletionResponse
10+
from comps.cores.proto.docarray import LLMParams
11+
from fastapi import Request
912

10-
MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
1113
MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
1214

1315
WHISPER_SERVER_HOST_IP = os.getenv("WHISPER_SERVER_HOST_IP", "0.0.0.0")
@@ -52,7 +54,7 @@ def align_outputs(self, data, cur_node, inputs, runtime_graph, llm_parameters_di
5254
return data
5355

5456

55-
class AudioQnAService:
57+
class AudioQnAService(Gateway):
5658
def __init__(self, host="0.0.0.0", port=8000):
5759
self.host = host
5860
self.port = port
@@ -90,9 +92,43 @@ def add_remote_service(self):
9092
self.megaservice.add(asr).add(llm).add(tts)
9193
self.megaservice.flow_to(asr, llm)
9294
self.megaservice.flow_to(llm, tts)
93-
self.gateway = AudioQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
95+
96+
async def handle_request(self, request: Request):
97+
data = await request.json()
98+
99+
chat_request = AudioChatCompletionRequest.parse_obj(data)
100+
parameters = LLMParams(
101+
# relatively lower max_tokens for audio conversation
102+
max_tokens=chat_request.max_tokens if chat_request.max_tokens else 128,
103+
top_k=chat_request.top_k if chat_request.top_k else 10,
104+
top_p=chat_request.top_p if chat_request.top_p else 0.95,
105+
temperature=chat_request.temperature if chat_request.temperature else 0.01,
106+
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
107+
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
108+
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
109+
streaming=False, # TODO add streaming LLM output as input to TTS
110+
)
111+
result_dict, runtime_graph = await self.megaservice.schedule(
112+
initial_inputs={"byte_str": chat_request.audio}, llm_parameters=parameters
113+
)
114+
115+
last_node = runtime_graph.all_leaves()[-1]
116+
response = result_dict[last_node]["byte_str"]
117+
118+
return response
119+
120+
def start(self):
121+
super().__init__(
122+
megaservice=self.megaservice,
123+
host=self.host,
124+
port=self.port,
125+
endpoint=str(MegaServiceEndpoint.AUDIO_QNA),
126+
input_datatype=AudioChatCompletionRequest,
127+
output_datatype=ChatCompletionResponse,
128+
)
94129

95130

96131
if __name__ == "__main__":
97-
audioqna = AudioQnAService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
132+
audioqna = AudioQnAService(port=MEGA_SERVICE_PORT)
98133
audioqna.add_remote_service()
134+
audioqna.start()

AvatarChatbot/avatarchatbot.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
import os
66
import sys
77

8-
from comps import AvatarChatbotGateway, MicroService, ServiceOrchestrator, ServiceType
8+
from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
9+
from comps.cores.proto.api_protocol import AudioChatCompletionRequest, ChatCompletionResponse
10+
from comps.cores.proto.docarray import LLMParams
11+
from fastapi import Request
912

10-
MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
1113
MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
1214
ASR_SERVICE_HOST_IP = os.getenv("ASR_SERVICE_HOST_IP", "0.0.0.0")
1315
ASR_SERVICE_PORT = int(os.getenv("ASR_SERVICE_PORT", 9099))
@@ -27,7 +29,7 @@ def check_env_vars(env_var_list):
2729
print("All environment variables are set.")
2830

2931

30-
class AvatarChatbotService:
32+
class AvatarChatbotService(Gateway):
3133
def __init__(self, host="0.0.0.0", port=8000):
3234
self.host = host
3335
self.port = port
@@ -70,7 +72,39 @@ def add_remote_service(self):
7072
self.megaservice.flow_to(asr, llm)
7173
self.megaservice.flow_to(llm, tts)
7274
self.megaservice.flow_to(tts, animation)
73-
self.gateway = AvatarChatbotGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
75+
76+
async def handle_request(self, request: Request):
77+
data = await request.json()
78+
79+
chat_request = AudioChatCompletionRequest.model_validate(data)
80+
parameters = LLMParams(
81+
# relatively lower max_tokens for audio conversation
82+
max_tokens=chat_request.max_tokens if chat_request.max_tokens else 128,
83+
top_k=chat_request.top_k if chat_request.top_k else 10,
84+
top_p=chat_request.top_p if chat_request.top_p else 0.95,
85+
temperature=chat_request.temperature if chat_request.temperature else 0.01,
86+
repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
87+
streaming=False, # TODO add streaming LLM output as input to TTS
88+
)
89+
# print(parameters)
90+
91+
result_dict, runtime_graph = await self.megaservice.schedule(
92+
initial_inputs={"byte_str": chat_request.audio}, llm_parameters=parameters
93+
)
94+
95+
last_node = runtime_graph.all_leaves()[-1]
96+
response = result_dict[last_node]["video_path"]
97+
return response
98+
99+
def start(self):
100+
super().__init__(
101+
megaservice=self.megaservice,
102+
host=self.host,
103+
port=self.port,
104+
endpoint=str(MegaServiceEndpoint.AVATAR_CHATBOT),
105+
input_datatype=AudioChatCompletionRequest,
106+
output_datatype=ChatCompletionResponse,
107+
)
74108

75109

76110
if __name__ == "__main__":
@@ -89,5 +123,6 @@ def add_remote_service(self):
89123
]
90124
)
91125

92-
avatarchatbot = AvatarChatbotService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
126+
avatarchatbot = AvatarChatbotService(port=MEGA_SERVICE_PORT)
93127
avatarchatbot.add_remote_service()
128+
avatarchatbot.start()

ChatQnA/chatqna.py

Lines changed: 76 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,17 @@
66
import os
77
import re
88

9-
from comps import ChatQnAGateway, MicroService, ServiceOrchestrator, ServiceType
9+
from comps import Gateway, MegaServiceEndpoint, MicroService, ServiceOrchestrator, ServiceType
10+
from comps.cores.proto.api_protocol import (
11+
ChatCompletionRequest,
12+
ChatCompletionResponse,
13+
ChatCompletionResponseChoice,
14+
ChatMessage,
15+
UsageInfo,
16+
)
17+
from comps.cores.proto.docarray import LLMParams, RerankerParms, RetrieverParms
18+
from fastapi import Request
19+
from fastapi.responses import StreamingResponse
1020
from langchain_core.prompts import PromptTemplate
1121

1222

@@ -35,7 +45,6 @@ def generate_rag_prompt(question, documents):
3545
return template.format(context=context_str, question=question)
3646

3747

38-
MEGA_SERVICE_HOST_IP = os.getenv("MEGA_SERVICE_HOST_IP", "0.0.0.0")
3948
MEGA_SERVICE_PORT = int(os.getenv("MEGA_SERVICE_PORT", 8888))
4049
GUARDRAIL_SERVICE_HOST_IP = os.getenv("GUARDRAIL_SERVICE_HOST_IP", "0.0.0.0")
4150
GUARDRAIL_SERVICE_PORT = int(os.getenv("GUARDRAIL_SERVICE_PORT", 80))
@@ -178,13 +187,14 @@ def align_generator(self, gen, **kwargs):
178187
yield "data: [DONE]\n\n"
179188

180189

181-
class ChatQnAService:
190+
class ChatQnAService(Gateway):
182191
def __init__(self, host="0.0.0.0", port=8000):
183192
self.host = host
184193
self.port = port
185194
ServiceOrchestrator.align_inputs = align_inputs
186195
ServiceOrchestrator.align_outputs = align_outputs
187196
ServiceOrchestrator.align_generator = align_generator
197+
188198
self.megaservice = ServiceOrchestrator()
189199

190200
def add_remote_service(self):
@@ -228,7 +238,6 @@ def add_remote_service(self):
228238
self.megaservice.flow_to(embedding, retriever)
229239
self.megaservice.flow_to(retriever, rerank)
230240
self.megaservice.flow_to(rerank, llm)
231-
self.gateway = ChatQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
232241

233242
def add_remote_service_without_rerank(self):
234243

@@ -261,7 +270,6 @@ def add_remote_service_without_rerank(self):
261270
self.megaservice.add(embedding).add(retriever).add(llm)
262271
self.megaservice.flow_to(embedding, retriever)
263272
self.megaservice.flow_to(retriever, llm)
264-
self.gateway = ChatQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
265273

266274
def add_remote_service_with_guardrails(self):
267275
guardrail_in = MicroService(
@@ -319,7 +327,66 @@ def add_remote_service_with_guardrails(self):
319327
self.megaservice.flow_to(retriever, rerank)
320328
self.megaservice.flow_to(rerank, llm)
321329
# self.megaservice.flow_to(llm, guardrail_out)
322-
self.gateway = ChatQnAGateway(megaservice=self.megaservice, host="0.0.0.0", port=self.port)
330+
331+
async def handle_request(self, request: Request):
332+
data = await request.json()
333+
stream_opt = data.get("stream", True)
334+
chat_request = ChatCompletionRequest.parse_obj(data)
335+
prompt = self._handle_message(chat_request.messages)
336+
parameters = LLMParams(
337+
max_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
338+
top_k=chat_request.top_k if chat_request.top_k else 10,
339+
top_p=chat_request.top_p if chat_request.top_p else 0.95,
340+
temperature=chat_request.temperature if chat_request.temperature else 0.01,
341+
frequency_penalty=chat_request.frequency_penalty if chat_request.frequency_penalty else 0.0,
342+
presence_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 0.0,
343+
repetition_penalty=chat_request.repetition_penalty if chat_request.repetition_penalty else 1.03,
344+
streaming=stream_opt,
345+
chat_template=chat_request.chat_template if chat_request.chat_template else None,
346+
)
347+
retriever_parameters = RetrieverParms(
348+
search_type=chat_request.search_type if chat_request.search_type else "similarity",
349+
k=chat_request.k if chat_request.k else 4,
350+
distance_threshold=chat_request.distance_threshold if chat_request.distance_threshold else None,
351+
fetch_k=chat_request.fetch_k if chat_request.fetch_k else 20,
352+
lambda_mult=chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
353+
score_threshold=chat_request.score_threshold if chat_request.score_threshold else 0.2,
354+
)
355+
reranker_parameters = RerankerParms(
356+
top_n=chat_request.top_n if chat_request.top_n else 1,
357+
)
358+
result_dict, runtime_graph = await self.megaservice.schedule(
359+
initial_inputs={"text": prompt},
360+
llm_parameters=parameters,
361+
retriever_parameters=retriever_parameters,
362+
reranker_parameters=reranker_parameters,
363+
)
364+
for node, response in result_dict.items():
365+
if isinstance(response, StreamingResponse):
366+
return response
367+
last_node = runtime_graph.all_leaves()[-1]
368+
response = result_dict[last_node]["text"]
369+
choices = []
370+
usage = UsageInfo()
371+
choices.append(
372+
ChatCompletionResponseChoice(
373+
index=0,
374+
message=ChatMessage(role="assistant", content=response),
375+
finish_reason="stop",
376+
)
377+
)
378+
return ChatCompletionResponse(model="chatqna", choices=choices, usage=usage)
379+
380+
def start(self):
381+
382+
super().__init__(
383+
megaservice=self.megaservice,
384+
host=self.host,
385+
port=self.port,
386+
endpoint=str(MegaServiceEndpoint.CHAT_QNA),
387+
input_datatype=ChatCompletionRequest,
388+
output_datatype=ChatCompletionResponse,
389+
)
323390

324391

325392
if __name__ == "__main__":
@@ -329,10 +396,12 @@ def add_remote_service_with_guardrails(self):
329396

330397
args = parser.parse_args()
331398

332-
chatqna = ChatQnAService(host=MEGA_SERVICE_HOST_IP, port=MEGA_SERVICE_PORT)
399+
chatqna = ChatQnAService(port=MEGA_SERVICE_PORT)
333400
if args.without_rerank:
334401
chatqna.add_remote_service_without_rerank()
335402
elif args.with_guardrails:
336403
chatqna.add_remote_service_with_guardrails()
337404
else:
338405
chatqna.add_remote_service()
406+
407+
chatqna.start()

0 commit comments

Comments (0)